diff --git a/.flake8 b/.flake8
index a76067aac..89502acd5 100644
--- a/.flake8
+++ b/.flake8
@@ -24,6 +24,7 @@ exclude =
.git,
**/data/**,
icefall/shared/make_kn_lm.py,
+ egs/librispeech/ASR/transducer_emformer/train.py,
icefall/__init__.py
ignore =
diff --git a/egs/librispeech/ASR/transducer_emformer/client/index.html b/egs/librispeech/ASR/transducer_emformer/client/index.html
new file mode 100644
index 000000000..d0fec4fc1
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/index.html
@@ -0,0 +1,62 @@
+<!-- The original markup of this page was lost in this plain-text rendering; -->
+<!-- only the visible page text survives: -->
+<!-- Title: Next-gen Kaldi demo -->
+<!-- "Upload": Recognition from a selected file -->
+<!-- "Record": Recognition from real-time recordings -->
+<!-- Code is available at -->
+<!-- https://github.com/k2-fsa/icefall/tree/streaming/egs/librispeech/ASR/transducer_emformer -->
diff --git a/egs/librispeech/ASR/transducer_emformer/client/nav-partial.html b/egs/librispeech/ASR/transducer_emformer/client/nav-partial.html
new file mode 100644
index 000000000..513c1511f
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/nav-partial.html
@@ -0,0 +1,22 @@
+<!-- Navigation bar markup; its content was lost in this plain-text rendering. -->
diff --git a/egs/librispeech/ASR/transducer_emformer/client/record.html b/egs/librispeech/ASR/transducer_emformer/client/record.html
new file mode 100644
index 000000000..4a06e0ec9
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/record.html
@@ -0,0 +1,71 @@
+<!-- The original markup of this page was lost in this plain-text rendering; -->
+<!-- only the visible page text survives: -->
+<!-- Title: Next-gen Kaldi demo (Upload file for recognition) -->
+<!-- Heading: Recognition from real-time recordings -->
diff --git a/egs/librispeech/ASR/transducer_emformer/client/record.js b/egs/librispeech/ASR/transducer_emformer/client/record.js
new file mode 100644
index 000000000..168bfdaa8
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/record.js
@@ -0,0 +1,333 @@
+// see https://mdn.github.io/web-dictaphone/scripts/app.js
+// and https://gist.github.com/meziantou/edb7217fddfbb70e899e
+
+var socket;
+function initWebSocket() {
+ socket = new WebSocket("ws://localhost:6008/");
+
+ // Connection opened
+ socket.addEventListener('open', function(event) {
+ console.log('connected');
+ document.getElementById('record').disabled = false;
+ });
+
+ // Connection closed
+ socket.addEventListener('close', function(event) {
+ console.log('disconnected');
+ document.getElementById('record').disabled = true;
+ initWebSocket();
+ });
+
+ // Listen for messages
+ socket.addEventListener('message', function(event) {
+ document.getElementById('results').innerHTML = event.data;
+ console.log('Received message: ', event.data);
+ });
+}
+
+const recordBtn = document.getElementById('record');
+const stopBtn = document.getElementById('stop');
+const clearBtn = document.getElementById('clear');
+const soundClips = document.getElementById('sound-clips');
+const canvas = document.getElementById('canvas');
+const mainSection = document.querySelector('.container');
+
+stopBtn.disabled = true;
+
+let audioCtx;
+const canvasCtx = canvas.getContext("2d");
+let mediaStream;
+let analyser;
+
+let expectedSampleRate = 16000;
+let recordSampleRate; // the sampleRate of the microphone
+let recorder = null; // ScriptProcessorNode used to capture audio from the microphone
+let leftchannel = []; // TODO: Use a single channel
+
+let recordingLength = 0; // number of samples so far
+
+clearBtn.onclick =
+ function() { document.getElementById('results').innerHTML = ''; };
+
+// copied/modified from https://mdn.github.io/web-dictaphone/
+// and
+// https://gist.github.com/meziantou/edb7217fddfbb70e899e
+if (navigator.mediaDevices.getUserMedia) {
+ console.log('getUserMedia supported.');
+
+ // see https://w3c.github.io/mediacapture-main/#dom-mediadevices-getusermedia
+ const constraints = {audio : true};
+
+ let onSuccess = function(stream) {
+ if (!audioCtx) {
+ audioCtx = new AudioContext();
+ }
+ console.log(audioCtx);
+ recordSampleRate = audioCtx.sampleRate;
+ console.log('sample rate ' + recordSampleRate);
+
+ // creates an audio node from the microphone incoming stream
+ mediaStream = audioCtx.createMediaStreamSource(stream);
+ console.log(mediaStream);
+
+ // https://developer.mozilla.org/en-US/docs/Web/API/AudioContext/createScriptProcessor
+ // bufferSize: the onaudioprocess event is called when the buffer is full
+ var bufferSize = 2048;
+ var numberOfInputChannels = 2;
+ var numberOfOutputChannels = 2;
+ if (audioCtx.createScriptProcessor) {
+ recorder = audioCtx.createScriptProcessor(
+ bufferSize, numberOfInputChannels, numberOfOutputChannels);
+ } else {
+ recorder = audioCtx.createJavaScriptNode(
+ bufferSize, numberOfInputChannels, numberOfOutputChannels);
+ }
+ console.log(recorder);
+
+ recorder.onaudioprocess = function(e) {
+      let samples = new Float32Array(e.inputBuffer.getChannelData(0));
+ samples = downsampleBuffer(samples, expectedSampleRate);
+
+ let buf = new Int16Array(samples.length);
+ for (var i = 0; i < samples.length; ++i) {
+ let s = samples[i];
+ if (s >= 1)
+ s = 1;
+ else if (s <= -1)
+ s = -1;
+
+ buf[i] = s * 32767;
+ }
+
+ socket.send(buf);
+ leftchannel.push(buf);
+ recordingLength += bufferSize;
+ console.log(recordingLength);
+ };
+
+ visualize(stream);
+ mediaStream.connect(analyser);
+
+ recordBtn.onclick = function() {
+ mediaStream.connect(recorder);
+ mediaStream.connect(analyser);
+ recorder.connect(audioCtx.destination);
+
+ console.log("recorder started");
+ recordBtn.style.background = "red";
+
+ stopBtn.disabled = false;
+ recordBtn.disabled = true;
+ };
+
+ stopBtn.onclick = function() {
+ console.log("recorder stopped");
+ socket.close();
+
+      // stop recording
+ recorder.disconnect(audioCtx.destination);
+ mediaStream.disconnect(recorder);
+ mediaStream.disconnect(analyser);
+
+ recordBtn.style.background = "";
+ recordBtn.style.color = "";
+ // mediaRecorder.requestData();
+
+ stopBtn.disabled = true;
+ recordBtn.disabled = false;
+
+ const clipName =
+ prompt('Enter a name for your sound clip?', 'My unnamed clip');
+
+ const clipContainer = document.createElement('article');
+ const clipLabel = document.createElement('p');
+ const audio = document.createElement('audio');
+ const deleteButton = document.createElement('button');
+ clipContainer.classList.add('clip');
+ audio.setAttribute('controls', '');
+ deleteButton.textContent = 'Delete';
+ deleteButton.className = 'delete';
+
+ if (clipName === null) {
+ clipLabel.textContent = 'My unnamed clip';
+ } else {
+ clipLabel.textContent = clipName;
+ }
+
+ clipContainer.appendChild(audio);
+
+ clipContainer.appendChild(clipLabel);
+ clipContainer.appendChild(deleteButton);
+ soundClips.appendChild(clipContainer);
+
+ audio.controls = true;
+ let samples = flatten(leftchannel);
+ const blob = toWav(samples);
+
+ leftchannel = [];
+ const audioURL = window.URL.createObjectURL(blob);
+ audio.src = audioURL;
+ console.log("recorder stopped");
+
+ deleteButton.onclick = function(e) {
+ let evtTgt = e.target;
+ evtTgt.parentNode.parentNode.removeChild(evtTgt.parentNode);
+ };
+
+ clipLabel.onclick = function() {
+ const existingName = clipLabel.textContent;
+ const newClipName = prompt('Enter a new name for your sound clip?');
+ if (newClipName === null) {
+ clipLabel.textContent = existingName;
+ } else {
+ clipLabel.textContent = newClipName;
+ }
+ };
+ };
+ };
+
+ let onError = function(
+      err) { console.log('The following error occurred: ' + err); };
+
+ navigator.mediaDevices.getUserMedia(constraints).then(onSuccess, onError);
+} else {
+ console.log('getUserMedia not supported on your browser!');
+ alert('getUserMedia not supported on your browser!');
+}
+
+function visualize(stream) {
+ if (!audioCtx) {
+ audioCtx = new AudioContext();
+ }
+
+ const source = audioCtx.createMediaStreamSource(stream);
+
+ if (!analyser) {
+ analyser = audioCtx.createAnalyser();
+ analyser.fftSize = 2048;
+ }
+ const bufferLength = analyser.frequencyBinCount;
+ const dataArray = new Uint8Array(bufferLength);
+
+ // source.connect(analyser);
+ // analyser.connect(audioCtx.destination);
+
+  draw();
+
+ function draw() {
+    const WIDTH = canvas.width;
+ const HEIGHT = canvas.height;
+
+ requestAnimationFrame(draw);
+
+ analyser.getByteTimeDomainData(dataArray);
+
+ canvasCtx.fillStyle = 'rgb(200, 200, 200)';
+ canvasCtx.fillRect(0, 0, WIDTH, HEIGHT);
+
+ canvasCtx.lineWidth = 2;
+ canvasCtx.strokeStyle = 'rgb(0, 0, 0)';
+
+ canvasCtx.beginPath();
+
+ let sliceWidth = WIDTH * 1.0 / bufferLength;
+ let x = 0;
+
+ for (let i = 0; i < bufferLength; i++) {
+
+ let v = dataArray[i] / 128.0;
+ let y = v * HEIGHT / 2;
+
+ if (i === 0) {
+ canvasCtx.moveTo(x, y);
+ } else {
+ canvasCtx.lineTo(x, y);
+ }
+
+ x += sliceWidth;
+ }
+
+ canvasCtx.lineTo(canvas.width, canvas.height / 2);
+ canvasCtx.stroke();
+ }
+}
+
+window.onresize = function() { canvas.width = mainSection.offsetWidth; };
+
+window.onresize();
+
+// this function is copied/modified from
+// https://gist.github.com/meziantou/edb7217fddfbb70e899e
+function flatten(listOfSamples) {
+ let n = 0;
+ for (let i = 0; i < listOfSamples.length; ++i) {
+ n += listOfSamples[i].length;
+ }
+ let ans = new Int16Array(n);
+
+ let offset = 0;
+ for (let i = 0; i < listOfSamples.length; ++i) {
+ ans.set(listOfSamples[i], offset);
+ offset += listOfSamples[i].length;
+ }
+ return ans;
+}
+
+// this function is copied/modified from
+// https://gist.github.com/meziantou/edb7217fddfbb70e899e
+function toWav(samples) {
+ let buf = new ArrayBuffer(44 + samples.length * 2);
+ var view = new DataView(buf);
+
+ // http://soundfile.sapp.org/doc/WaveFormat/
+ // F F I R
+ view.setUint32(0, 0x46464952, true); // chunkID
+ view.setUint32(4, 36 + samples.length * 2, true); // chunkSize
+ // E V A W
+ view.setUint32(8, 0x45564157, true); // format
+ //
+ // t m f
+ view.setUint32(12, 0x20746d66, true); // subchunk1ID
+ view.setUint32(16, 16, true); // subchunk1Size, 16 for PCM
+ view.setUint32(20, 1, true); // audioFormat, 1 for PCM
+ view.setUint16(22, 1, true); // numChannels: 1 channel
+ view.setUint32(24, expectedSampleRate, true); // sampleRate
+ view.setUint32(28, expectedSampleRate * 2, true); // byteRate
+ view.setUint16(32, 2, true); // blockAlign
+ view.setUint16(34, 16, true); // bitsPerSample
+ view.setUint32(36, 0x61746164, true); // Subchunk2ID
+ view.setUint32(40, samples.length * 2, true); // subchunk2Size
+
+ let offset = 44;
+ for (let i = 0; i < samples.length; ++i) {
+ view.setInt16(offset, samples[i], true);
+ offset += 2;
+ }
+
+ return new Blob([ view ], {type : 'audio/wav'});
+}
+
+// this function is copied from
+// https://github.com/awslabs/aws-lex-browser-audio-capture/blob/master/lib/worker.js#L46
+function downsampleBuffer(buffer, exportSampleRate) {
+ if (exportSampleRate === recordSampleRate) {
+ return buffer;
+ }
+ var sampleRateRatio = recordSampleRate / exportSampleRate;
+ var newLength = Math.round(buffer.length / sampleRateRatio);
+ var result = new Float32Array(newLength);
+ var offsetResult = 0;
+ var offsetBuffer = 0;
+ while (offsetResult < result.length) {
+ var nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio);
+ var accum = 0, count = 0;
+ for (var i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
+ accum += buffer[i];
+ count++;
+ }
+ result[offsetResult] = accum / count;
+ offsetResult++;
+ offsetBuffer = nextOffsetBuffer;
+ }
+ return result;
+};
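For reference, each binary WebSocket frame produced by record.js carries 16-bit little-endian PCM at 16 kHz: the float samples from the ScriptProcessorNode are downsampled by bucket averaging, clamped to [-1, 1], and scaled by 32767. Below is a minimal Python sketch of that same preprocessing; NumPy and the 48 kHz capture rate are assumptions for illustration only, not part of this patch.

```python
import numpy as np


def downsample(samples: np.ndarray, in_rate: int, out_rate: int) -> np.ndarray:
    """Bucket-average `samples` (float32 in [-1, 1]) from in_rate down to
    out_rate, mirroring downsampleBuffer() in record.js."""
    if in_rate == out_rate:
        return samples
    ratio = in_rate / out_rate
    out = np.empty(round(len(samples) / ratio), dtype=np.float32)
    pos = 0
    for i in range(len(out)):
        next_pos = round((i + 1) * ratio)
        out[i] = samples[pos:next_pos].mean()
        pos = next_pos
    return out


def to_int16(samples: np.ndarray) -> np.ndarray:
    """Clamp to [-1, 1] and scale to int16, as record.js does before socket.send()."""
    return (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)


# One 2048-sample microphone buffer captured at an assumed 48 kHz:
mic = np.random.uniform(-1, 1, 2048).astype(np.float32)
frame = to_int16(downsample(mic, 48000, 16000)).tobytes()  # bytes sent per frame
```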
diff --git a/egs/librispeech/ASR/transducer_emformer/client/upload.html b/egs/librispeech/ASR/transducer_emformer/client/upload.html
new file mode 100644
index 000000000..afc1882a3
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/upload.html
@@ -0,0 +1,58 @@
+<!-- The original markup of this page was lost in this plain-text rendering; -->
+<!-- only the visible page text survives: -->
+<!-- Title: Next-gen Kaldi demo (Upload file for recognition) -->
+<!-- Heading: Recognition from a selected file -->
diff --git a/egs/librispeech/ASR/transducer_emformer/client/upload.js b/egs/librispeech/ASR/transducer_emformer/client/upload.js
new file mode 100644
index 000000000..a2b0f8644
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/client/upload.js
@@ -0,0 +1,60 @@
+/**
+References
+https://developer.mozilla.org/en-US/docs/Web/API/FileList
+https://developer.mozilla.org/en-US/docs/Web/API/FileReader
+https://javascript.info/arraybuffer-binary-arrays
+https://developer.mozilla.org/zh-CN/docs/Web/API/WebSocket
+https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send
+*/
+
+var socket;
+function initWebSocket() {
+ socket = new WebSocket("ws://localhost:6008/");
+
+ // Connection opened
+ socket.addEventListener(
+ 'open',
+ function(event) { document.getElementById('file').disabled = false; });
+
+ // Connection closed
+ socket.addEventListener('close', function(event) {
+ document.getElementById('file').disabled = true;
+ initWebSocket();
+ });
+
+ // Listen for messages
+ socket.addEventListener('message', function(event) {
+ document.getElementById('results').innerHTML = event.data;
+ console.log('Received message: ', event.data);
+ });
+}
+
+function onFileChange() {
+ var files = document.getElementById("file").files;
+
+ if (files.length == 0) {
+ console.log('No file selected');
+ return;
+ }
+
+ console.log('files: ' + files);
+
+ const file = files[0];
+ console.log(file);
+ console.log('file.name ' + file.name);
+ console.log('file.type ' + file.type);
+ console.log('file.size ' + file.size);
+
+ let reader = new FileReader();
+ reader.onload = function() {
+ let view = new Uint8Array(reader.result);
+ console.log('bytes: ' + view.byteLength);
+ // we assume the input file is a wav file.
+ // TODO: add some checks here.
+ let body = view.subarray(44);
+ socket.send(body);
+ socket.send(JSON.stringify({'eof' : 1}));
+ };
+
+ reader.readAsArrayBuffer(file);
+}
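upload.js assumes the selected file is a PCM WAV with a canonical 44-byte header and simply sends everything after byte 44, followed by a text frame {"eof": 1}; the TODO above notes that no checks are done yet. As a hedged illustration of what such a check could look like, here is a small Python sketch that validates the header fields at the same offsets that toWav() in record.js writes. The function name and the 16 kHz expectation are assumptions, not code from this patch.

```python
import struct


def check_wav_header(data: bytes, expected_rate: int = 16000) -> None:
    """Validate a canonical 44-byte PCM WAV header.

    Offsets follow http://soundfile.sapp.org/doc/WaveFormat/, the same
    layout that toWav() in record.js produces.
    """
    assert len(data) > 44, "file too short to hold a WAV header"
    assert data[0:4] == b"RIFF", "missing RIFF chunk ID"
    assert data[8:12] == b"WAVE", "missing WAVE format tag"
    assert data[12:16] == b"fmt ", "missing fmt subchunk"
    assert data[36:40] == b"data", "missing data subchunk"

    audio_format, num_channels = struct.unpack_from("<HH", data, 20)
    (sample_rate,) = struct.unpack_from("<I", data, 24)
    (bits_per_sample,) = struct.unpack_from("<H", data, 34)

    assert audio_format == 1, "not PCM"
    assert num_channels == 1, "expected mono audio"
    assert sample_rate == expected_rate, f"expected {expected_rate} Hz audio"
    assert bits_per_sample == 16, "expected 16-bit samples"


# Usage (path is illustrative):
#   data = open("test.wav", "rb").read()
#   check_wav_header(data)
#   payload = data[44:]   # the bytes upload.js sends over the WebSocket
```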
diff --git a/egs/librispeech/ASR/transducer_emformer/server.py b/egs/librispeech/ASR/transducer_emformer/server.py
new file mode 100755
index 000000000..35f66f60f
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/server.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python3
+import asyncio
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import websockets
+from streaming_decode import StreamList, get_parser, process_features
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import setup_logger
+
+g_params = None
+g_model = None
+g_sp = None
+
+
+def build_stream_list():
+ batch_size = 1 # will change it later
+
+ stream_list = StreamList(
+ batch_size=batch_size,
+ context_size=g_params.context_size,
+ decoding_method=g_params.decoding_method,
+ )
+ return stream_list
+
+
+async def echo(websocket):
+ logging.info(f"connected: {websocket.remote_address}")
+
+ stream_list = build_stream_list()
+
+ # number of frames before subsampling
+ segment_length = g_model.encoder.segment_length
+
+ right_context_length = g_model.encoder.right_context_length
+
+ # We add 3 here since the subsampling method is using
+ # ((len - 1) // 2 - 1) // 2)
+    # ((len - 1) // 2 - 1) // 2
+
+ async for message in websocket:
+ if isinstance(message, bytes):
+ samples = torch.frombuffer(message, dtype=torch.int16)
+ samples = samples.to(torch.float32) / 32768
+ stream_list.accept_waveform(
+ audio_samples=[samples],
+ sampling_rate=g_params.sampling_rate,
+ )
+
+ while True:
+ features, active_streams = stream_list.build_batch(
+ chunk_length=chunk_length,
+ segment_length=segment_length,
+ )
+
+ if features is not None:
+ process_features(
+ model=g_model,
+ features=features,
+ streams=active_streams,
+ params=g_params,
+ sp=g_sp,
+ )
+ results = []
+ for stream in stream_list.streams:
+ text = g_sp.decode(stream.decoding_result())
+ results.append(text)
+ await websocket.send(results[0])
+ else:
+ break
+ elif isinstance(message, str):
+ stream_list[0].input_finished()
+ while True:
+ features, active_streams = stream_list.build_batch(
+ chunk_length=chunk_length,
+ segment_length=segment_length,
+ )
+
+ if features is not None:
+ process_features(
+ model=g_model,
+ features=features,
+ streams=active_streams,
+ params=g_params,
+ sp=g_sp,
+ )
+ else:
+ break
+
+ results = []
+ for stream in stream_list.streams:
+ text = g_sp.decode(stream.decoding_result())
+ results.append(text)
+
+ await websocket.send(results[0])
+ await websocket.close()
+
+ logging.info(f"Closed: {websocket.remote_address}")
+
+
+async def loop():
+ logging.info("started")
+ async with websockets.serve(echo, "", 6008):
+ await asyncio.Future() # run forever
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ # Note: params.decoding_method is currently not used.
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ setup_logger(f"{params.res_dir}/log-streaming-decode")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ params.device = device
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if params.avg_last_n > 0:
+ filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ global g_params, g_model, g_sp
+ g_params = params
+ g_model = model
+ g_sp = sp
+
+ asyncio.run(loop())
+
+
+if __name__ == "__main__":
+ torch.manual_seed(20220506)
+ main()
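To exercise server.py without a browser, a few lines of Python with the websockets package are enough: stream the raw 16-bit samples of a 16 kHz mono WAV as binary frames, send any text frame to mark the end of input, and print whatever transcripts come back until the server closes the connection. This is a sketch under stated assumptions (the server listening on ws://localhost:6008 as configured above, an illustrative file path and chunk size); it is not part of the patch.

```python
import asyncio
import wave

import websockets


async def receive(ws):
    # Print partial and final transcripts until the server closes the socket.
    try:
        async for message in ws:
            print("result:", message)
    except websockets.ConnectionClosed:
        pass


async def run(wav_path: str = "test.wav"):
    # The browser clients send 16-bit PCM at 16 kHz, mono; do the same here.
    with wave.open(wav_path, "rb") as f:
        assert f.getframerate() == 16000
        assert f.getnchannels() == 1
        assert f.getsampwidth() == 2
        samples = f.readframes(f.getnframes())

    async with websockets.connect("ws://localhost:6008/") as ws:
        receiver = asyncio.create_task(receive(ws))

        chunk = 3200 * 2  # 0.2 s of audio per binary frame (illustrative)
        for start in range(0, len(samples), chunk):
            await ws.send(samples[start : start + chunk])
            await asyncio.sleep(0.1)  # crude pacing to mimic live capture

        await ws.send('{"eof": 1}')  # any text frame ends the utterance
        await receiver


if __name__ == "__main__":
    asyncio.run(run())
```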
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index 8ebfbb210..2064bd344 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -233,6 +233,9 @@ class StreamList(object):
for _ in range(batch_size)
]
+ def __getitem__(self, i) -> FeatureExtractionStream:
+ return self.streams[i]
+
@property
def done(self) -> bool:
"""Return True if all streams have reached end of utterance.
@@ -667,8 +670,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.device = device
diff --git a/egs/librispeech/ASR/transducer_emformer/train.py b/egs/librispeech/ASR/transducer_emformer/train.py
index 9798fe5e6..dae30f91b 100755
--- a/egs/librispeech/ASR/transducer_emformer/train.py
+++ b/egs/librispeech/ASR/transducer_emformer/train.py
@@ -378,6 +378,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
+ unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@@ -811,8 +812,9 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)