First attempt to add WEB interface for emformer model.
parent 52f19df07d
commit 3fd4f1d13f
.flake8 (1 line added)

@@ -17,6 +17,7 @@ exclude =
   .git,
   **/data/**,
   icefall/shared/make_kn_lm.py,
+  egs/librispeech/ASR/transducer_emformer/train.py,
   icefall/__init__.py

 ignore =
@@ -2,35 +2,50 @@
 <html lang="en">
   <head>
     <!-- Required meta tags -->
-    <meta charset="utf-8">
-    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
+    <meta charset="utf-8"></meta>
+    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"></meta>

     <!-- Bootstrap CSS -->
     <link rel="stylesheet"
       href="https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css"
       integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T"
-      crossorigin="anonymous"/>
+      crossorigin="anonymous">
+    </link>

     <title>Hello next-gen Kaldi</title>
   </head>

-  <body>
+  <body onload="initWebSocket()">
+
+    <h1>Hello next-gen Kaldi</h1>
+    <div class="mb-3">
+      <label for="file" class="form-label">Select file</label>
+      <input class="form-control" type="file" id="file" accept=".wav" onchange="onFileChange()" disabled="true"></input>
+    </div>
+
+    <div class="mb-3">
+      <label for="results" class="form-label">Recognition results</label>
+      <textarea class="form-control" id="results" rows="3"></textarea>
+    </div>
+
     <!-- Optional JavaScript -->
     <!-- jQuery first, then Popper.js, then Bootstrap JS -->
     <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js"
       integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo"
-      crossorigin="anonymous"/>
+      crossorigin="anonymous">
     </script>

     <script src="https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js"
       integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1"
-      crossorigin="anonymous"/>
+      crossorigin="anonymous">
     </script>

     <script src="https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js"
       integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM"
-      crossorigin="anonymous"/>
+      crossorigin="anonymous">
     </script>
+
+    <script src="./main.js"> </script>

   </body>
 </html>
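Nothing in this commit serves the page above (its exact file name is not shown in this capture); it only needs to sit next to main.js and be reachable over HTTP, after which the browser opens the websocket on port 6008 itself. A minimal sketch, assuming both files live in a client/ directory and port 8000 is free:

from functools import partial
from http.server import HTTPServer, SimpleHTTPRequestHandler

# hedged example: serve ./client at http://localhost:8000; the directory
# and port are our choice, not part of the commit.
# equivalent one-liner: python3 -m http.server 8000 --directory client
handler = partial(SimpleHTTPRequestHandler, directory="client")
HTTPServer(("", 8000), handler).serve_forever()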
egs/librispeech/ASR/transducer_emformer/client/main.js (new file, 60 lines)

@@ -0,0 +1,60 @@
/**
References
https://developer.mozilla.org/en-US/docs/Web/API/FileList
https://developer.mozilla.org/en-US/docs/Web/API/FileReader
https://javascript.info/arraybuffer-binary-arrays
https://developer.mozilla.org/zh-CN/docs/Web/API/WebSocket
https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send
*/

var socket;
function initWebSocket() {
  socket = new WebSocket("ws://localhost:6008/");

  // Connection opened
  socket.addEventListener(
      'open',
      function(event) { document.getElementById('file').disabled = false; });

  // Connection closed
  socket.addEventListener('close', function(event) {
    document.getElementById('file').disabled = true;
    initWebSocket();
  });

  // Listen for messages
  socket.addEventListener('message', function(event) {
    document.getElementById('results').innerHTML = event.data;
    console.log('Received message: ', event.data);
  });
}

function onFileChange() {
  var files = document.getElementById("file").files;

  if (files.length == 0) {
    console.log('No file selected');
    return;
  }

  console.log('files: ' + files);

  const file = files[0];
  console.log(file);
  console.log('file.name ' + file.name);
  console.log('file.type ' + file.type);
  console.log('file.size ' + file.size);

  let reader = new FileReader();
  reader.onload = function() {
    let view = new Int16Array(reader.result);
    console.log('bytes: ' + view.byteLength);
    // we assume the input file is a wav file.
    // TODO: add some checks here.
    let body = view.subarray(44);
    socket.send(body);
    socket.send(JSON.stringify({'eof' : 1}));
  };

  reader.readAsArrayBuffer(file);
}
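One detail worth flagging alongside the TODO above: view is an Int16Array, so view.subarray(44) skips 44 samples (88 bytes), not the 44-byte RIFF/WAVE header the comment implies. A hedged Python sketch of the kind of check the TODO asks for, assuming the canonical 44-byte header layout (the helper name is ours):

import struct


def strip_wav_header(data: bytes) -> bytes:
    """Return the PCM payload of a wav file with a canonical 44-byte header."""
    assert data[:4] == b"RIFF" and data[8:12] == b"WAVE", "not a wav file"
    num_channels, sample_rate = struct.unpack("<HI", data[22:28])
    (bits_per_sample,) = struct.unpack("<H", data[34:36])
    assert num_channels == 1, "expect mono audio"
    assert sample_rate == 16000, "expect 16 kHz audio"
    assert bits_per_sample == 16, "expect 16-bit samples"
    return data[44:]  # skip exactly 44 bytes, i.e. 22 int16 samples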
egs/librispeech/ASR/transducer_emformer/server.py (new executable file, 182 lines)

@@ -0,0 +1,182 @@
#!/usr/bin/env python3
import asyncio
import logging
from pathlib import Path

import sentencepiece as spm
import torch
import websockets
from streaming_decode import StreamList, get_parser, process_features
from train import get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import setup_logger

g_params = None
g_model = None
g_sp = None


def build_stream_list():
    batch_size = 1  # will change it later

    stream_list = StreamList(
        batch_size=batch_size,
        context_size=g_params.context_size,
        decoding_method=g_params.decoding_method,
    )
    return stream_list


async def echo(websocket):
    logging.info(f"connected: {websocket.remote_address}")

    stream_list = build_stream_list()

    # number of frames before subsampling
    segment_length = g_model.encoder.segment_length

    right_context_length = g_model.encoder.right_context_length

    # We add 3 here since the subsampling method is using
    # ((len - 1) // 2 - 1) // 2
    chunk_length = (segment_length + 3) + right_context_length

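    # e.g., if segment_length is 16 and right_context_length is 4, then
    # chunk_length is 23; without the extra 3 frames, a 16-frame chunk
    # would subsample to ((16 - 1) // 2 - 1) // 2 = 3 frames rather than
    # the intended 16 / 4 = 4, since the two "- 1" terms eat into the chunk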
    async for message in websocket:
        if isinstance(message, bytes):
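            # each binary frame carries raw int16 PCM samples; dividing by
            # 32768 below rescales them to float32 in [-1, 1)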
            samples = torch.frombuffer(message, dtype=torch.int16)
            samples = samples.to(torch.float32) / 32768
            stream_list.accept_waveform(
                audio_samples=[samples],
                sampling_rate=g_params.sampling_rate,
            )

            while True:
                features, active_streams = stream_list.build_batch(
                    chunk_length=chunk_length,
                    segment_length=segment_length,
                )

                if features is not None:
                    process_features(
                        model=g_model,
                        features=features,
                        streams=active_streams,
                        params=g_params,
                        sp=g_sp,
                    )
                    results = []
                    for stream in stream_list.streams:
                        text = g_sp.decode(stream.decoding_result())
                        results.append(text)
                    await websocket.send(results[0])
                else:
                    break
        elif isinstance(message, str):
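            # any text frame (the client sends {"eof": 1}) marks the end of
            # input: flush what is still buffered, reply once, and close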
            stream_list[0].input_finished()
            while True:
                features, active_streams = stream_list.build_batch(
                    chunk_length=chunk_length,
                    segment_length=segment_length,
                )

                if features is not None:
                    process_features(
                        model=g_model,
                        features=features,
                        streams=active_streams,
                        params=g_params,
                        sp=g_sp,
                    )
                else:
                    break

            results = []
            for stream in stream_list.streams:
                text = g_sp.decode(stream.decoding_result())
                results.append(text)

            await websocket.send(results[0])
            await websocket.close()

    logging.info(f"Closed: {websocket.remote_address}")


async def loop():
    logging.info("started")
    async with websockets.serve(echo, "", 6008):
        await asyncio.Future()  # run forever


def main():
    parser = get_parser()
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    # Note: params.decoding_method is currently not used.
    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

    setup_logger(f"{params.res_dir}/log-streaming-decode")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    params.device = device

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.avg_last_n > 0:
        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()
    model.device = device

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    global g_params, g_model, g_sp
    g_params = params
    g_model = model
    g_sp = sp

    asyncio.run(loop())


if __name__ == "__main__":
    torch.manual_seed(20220506)
    main()
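For testing the server without a browser, a small Python client can mimic main.js: one binary frame with the raw int16 payload, then one text frame to signal end of input. A minimal sketch, assuming a 16 kHz, 16-bit mono file named test.wav (the file name is ours):

import asyncio
import wave

import websockets


async def send_wav(path: str):
    with wave.open(path, "rb") as f:
        assert f.getnchannels() == 1, "expect mono audio"
        assert f.getsampwidth() == 2, "expect 16-bit samples"
        assert f.getframerate() == 16000, "expect 16 kHz audio"
        payload = f.readframes(f.getnframes())  # raw int16 bytes, no header

    async with websockets.connect("ws://localhost:6008/") as ws:
        await ws.send(payload)  # binary frame: decoded incrementally
        await ws.send('{"eof": 1}')  # text frame: flush and finish
        async for message in ws:  # ends when the server closes
            print("result:", message)


asyncio.run(send_wav("test.wav"))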
egs/librispeech/ASR/transducer_emformer/streaming_decode.py

@@ -233,6 +233,9 @@ class StreamList(object):
             for _ in range(batch_size)
         ]

+    def __getitem__(self, i) -> FeatureExtractionStream:
+        return self.streams[i]
+
     @property
     def done(self) -> bool:
         """Return True if all streams have reached end of utterance.
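This __getitem__ is what lets server.py write stream_list[0].input_finished(). A quick illustration of the equivalence (the constructor arguments are hypothetical):

stream_list = StreamList(
    batch_size=1,
    context_size=2,  # hypothetical values
    decoding_method="greedy_search",
)
assert stream_list[0] is stream_list.streams[0]
stream_list[0].input_finished()  # what server.py does on the final text frame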
@@ -667,8 +670,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

     params.device = device
egs/librispeech/ASR/transducer_emformer/train.py

@@ -378,6 +378,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
+        unk_id=params.unk_id,
         context_size=params.context_size,
     )
     return decoder

@@ -811,8 +812,9 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)