diff --git a/.flake8 b/.flake8
index 9f54c46ee..982c229d4 100644
--- a/.flake8
+++ b/.flake8
@@ -17,6 +17,7 @@ exclude =
.git,
**/data/**,
icefall/shared/make_kn_lm.py,
+ egs/librispeech/ASR/transducer_emformer/train.py,
icefall/__init__.py
ignore =
diff --git a/egs/librispeech/ASR/transducer_emformer/client/index.html b/egs/librispeech/ASR/transducer_emformer/client/index.html
index 85a21df49..5b6baa001 100644
--- a/egs/librispeech/ASR/transducer_emformer/client/index.html
+++ b/egs/librispeech/ASR/transducer_emformer/client/index.html
@@ -2,35 +2,50 @@
-<html>
-<head>
+<html lang="en">
+<head>
+  <!-- Bootstrap CSS -->
+  <link rel="stylesheet"
+        href="https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css"
+        integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T"
+        crossorigin="anonymous">
   <title>Hello next-gen Kaldi</title>
 </head>
 <body>
   <h1>Hello next-gen Kaldi</h1>
+
+  <input type="file" id="file" accept="audio/wav" onchange="onFileChange()" disabled />
+  <br>
+  <br>
+  <div id="results"></div>
+
+  <script src="main.js"></script>
+  <!-- jQuery, Popper and Bootstrap JS -->
+  <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js"
+          integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo"
+          crossorigin="anonymous"></script>
+  <script src="https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js"
+          integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1"
+          crossorigin="anonymous"></script>
+  <script src="https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js"
+          integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM"
+          crossorigin="anonymous"></script>
+</body>
+</html>
diff --git a/egs/librispeech/ASR/transducer_emformer/client/main.js b/egs/librispeech/ASR/transducer_emformer/client/main.js
new file mode 100644
index 000000000..a25eb5330
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/main.js
@@ -0,0 +1,60 @@
+/**
+References
+https://developer.mozilla.org/en-US/docs/Web/API/FileList
+https://developer.mozilla.org/en-US/docs/Web/API/FileReader
+https://javascript.info/arraybuffer-binary-arrays
+https://developer.mozilla.org/zh-CN/docs/Web/API/WebSocket
+https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send
+*/
+
+var socket;
+function initWebSocket() {
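+  // The port must match the one passed to websockets.serve() in ./server.py.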
+ socket = new WebSocket("ws://localhost:6008/");
+
+ // Connection opened
+ socket.addEventListener(
+ 'open',
+ function(event) { document.getElementById('file').disabled = false; });
+
+ // Connection closed
+ socket.addEventListener('close', function(event) {
+ document.getElementById('file').disabled = true;
+ initWebSocket();
+ });
+
+ // Listen for messages
+ socket.addEventListener('message', function(event) {
+ document.getElementById('results').innerHTML = event.data;
+ console.log('Received message: ', event.data);
+ });
+}
+
+function onFileChange() {
+ var files = document.getElementById("file").files;
+
+ if (files.length == 0) {
+ console.log('No file selected');
+ return;
+ }
+
+ console.log('files: ' + files);
+
+ const file = files[0];
+ console.log(file);
+ console.log('file.name ' + file.name);
+ console.log('file.type ' + file.type);
+ console.log('file.size ' + file.size);
+
+ let reader = new FileReader();
+ reader.onload = function() {
+ let view = new Int16Array(reader.result);
+ console.log('bytes: ' + view.byteLength);
+    // We assume the input is a 16-bit PCM wav file and skip its 44-byte
+    // header, which spans 22 elements of the Int16Array view.
+    // TODO: add some checks here.
+    let body = view.subarray(22);
+ socket.send(body);
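+    // Any text frame signals end of input; the server keys on
+    // isinstance(message, str) to finish decoding (see server.py).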
+ socket.send(JSON.stringify({'eof' : 1}));
+ };
+
+ reader.readAsArrayBuffer(file);
+}
diff --git a/egs/librispeech/ASR/transducer_emformer/server.py b/egs/librispeech/ASR/transducer_emformer/server.py
new file mode 100755
index 000000000..35f66f60f
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/server.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python3
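+"""
+A WebSocket server that streams recognition results for the Emformer
+transducer model.  It reuses the stream handling in ./streaming_decode.py.
+
+Usage sketch (the checkpoint arguments below are placeholders; pass the
+epoch/avg values that match your trained model):
+
+    ./transducer_emformer/server.py \
+        --epoch 39 \
+        --avg 10 \
+        --exp-dir transducer_emformer/exp \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Then open client/index.html in a browser and upload a wav file.
+"""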
+import asyncio
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import websockets
+from streaming_decode import StreamList, get_parser, process_features
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import setup_logger
+
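+# Shared, read-only state: initialized once in main() and used by every
+# connection served by echo().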
+g_params = None
+g_model = None
+g_sp = None
+
+
+def build_stream_list():
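+    """Create a StreamList holding `batch_size` decoding streams."""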
+    batch_size = 1  # TODO: support batch sizes greater than 1
+
+ stream_list = StreamList(
+ batch_size=batch_size,
+ context_size=g_params.context_size,
+ decoding_method=g_params.decoding_method,
+ )
+ return stream_list
+
+
+async def echo(websocket):
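+    """Receive audio from one client and send back partial decoding results.
+
+    Binary frames carry 16-bit PCM samples; a text frame marks end of input.
+    """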
+ logging.info(f"connected: {websocket.remote_address}")
+
+ stream_list = build_stream_list()
+
+ # number of frames before subsampling
+ segment_length = g_model.encoder.segment_length
+
+ right_context_length = g_model.encoder.right_context_length
+
+    # We add 3 here because the encoder subsamples its input by
+    # ((len - 1) // 2 - 1) // 2
+ chunk_length = (segment_length + 3) + right_context_length
+
+ async for message in websocket:
+ if isinstance(message, bytes):
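+            # The client sends raw 16-bit PCM samples (see client/main.js);
+            # convert them to float32 in [-1, 1).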
+ samples = torch.frombuffer(message, dtype=torch.int16)
+ samples = samples.to(torch.float32) / 32768
+ stream_list.accept_waveform(
+ audio_samples=[samples],
+ sampling_rate=g_params.sampling_rate,
+ )
+
+ while True:
+ features, active_streams = stream_list.build_batch(
+ chunk_length=chunk_length,
+ segment_length=segment_length,
+ )
+
+ if features is not None:
+ process_features(
+ model=g_model,
+ features=features,
+ streams=active_streams,
+ params=g_params,
+ sp=g_sp,
+ )
+ results = []
+ for stream in stream_list.streams:
+ text = g_sp.decode(stream.decoding_result())
+ results.append(text)
+ await websocket.send(results[0])
+ else:
+ break
+ elif isinstance(message, str):
+ stream_list[0].input_finished()
+ while True:
+ features, active_streams = stream_list.build_batch(
+ chunk_length=chunk_length,
+ segment_length=segment_length,
+ )
+
+ if features is not None:
+ process_features(
+ model=g_model,
+ features=features,
+ streams=active_streams,
+ params=g_params,
+ sp=g_sp,
+ )
+ else:
+ break
+
+ results = []
+ for stream in stream_list.streams:
+ text = g_sp.decode(stream.decoding_result())
+ results.append(text)
+
+ await websocket.send(results[0])
+ await websocket.close()
+
+ logging.info(f"Closed: {websocket.remote_address}")
+
+
+async def loop():
+ logging.info("started")
+ async with websockets.serve(echo, "", 6008):
+ await asyncio.Future() # run forever
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ setup_logger(f"{params.res_dir}/log-streaming-decode")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ params.device = device
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if params.avg_last_n > 0:
+ filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+        # Guard against negative epochs when params.avg > params.epoch + 1.
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ global g_params, g_model, g_sp
+ g_params = params
+ g_model = model
+ g_sp = sp
+
+ asyncio.run(loop())
+
+
+if __name__ == "__main__":
+ torch.manual_seed(20220506)
+ main()
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index 8ebfbb210..2064bd344 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -233,6 +233,9 @@ class StreamList(object):
for _ in range(batch_size)
]
+ def __getitem__(self, i) -> FeatureExtractionStream:
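+        """Return the i-th stream."""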
+ return self.streams[i]
+
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
@@ -667,8 +670,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.device = device
diff --git a/egs/librispeech/ASR/transducer_emformer/train.py b/egs/librispeech/ASR/transducer_emformer/train.py
index 9798fe5e6..dae30f91b 100755
--- a/egs/librispeech/ASR/transducer_emformer/train.py
+++ b/egs/librispeech/ASR/transducer_emformer/train.py
@@ -378,6 +378,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
+ unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@@ -811,8 +812,9 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)