Add greedy search for streaming zipformer CTC.

Fangjun Kuang 2023-12-13 17:15:09 +08:00
parent 29130fbf27
commit ac1e99f6c1
4 changed files with 63 additions and 68 deletions

View File

@ -50,6 +50,26 @@ log "----------------------------------------"
ls -lh $repo/exp/
log "------------------------------------------------------------"
log "Test exported streaming ONNX CTC models (greedy search) "
log "------------------------------------------------------------"
test_wavs=(
DEV_T0000000000.wav
DEV_T0000000001.wav
DEV_T0000000002.wav
TEST_MEETING_T0000000113.wav
TEST_MEETING_T0000000219.wav
TEST_MEETING_T0000000351.wav
)
for w in ${test_wavs[@]}; do
./zipformer/onnx_pretrained-streaming-ctc.py \
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
$repo/test_wavs/$w
done
log "Upload onnx CTC models to huggingface"
url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
GIT_LFS_SKIP_SMUDGE=1 git clone $url
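Note: GIT_LFS_SKIP_SMUDGE=1 makes the clone skip downloading existing LFS blobs. A plausible continuation of this upload step, not shown in this hunk, copies the exported models into the clone and pushes them; the directory name below is taken from the URL above, everything else is an assumption:

dst=sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128*.onnx $dst/
cp -v $repo/data/lang_bpe_2000/tokens.txt $dst/
cd $dst
git lfs track "*.onnx"
git add .
git commit -m "Add streaming zipformer CTC models"
git push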

View File

@ -4,7 +4,6 @@ on:
push:
branches:
- master
- streaming-ctc-decoding
workflow_dispatch:

View File

@ -49,7 +49,7 @@ You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
./zipformer/onnx_pretrained-streaming-ctc.py \
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
$repo/test_wavs/DEV_T0000000001.wav
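The int8 model mentioned above is used the same way; only the model filename changes (the filename below is the one exercised by the CI script earlier in this commit):

./zipformer/onnx_pretrained-streaming-ctc.py \
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
--tokens $repo/data/lang_bpe_2000/tokens.txt \
$repo/test_wavs/DEV_T0000000001.wav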
Note: Even though this script only supports decoding a single file,
@ -58,9 +58,8 @@ the exported ONNX models do support batch processing.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
@ -111,7 +110,7 @@ class OnnxModel:
self.init_model(model_filename)
def init_model(self, encoder_model_filename: str):
def init_model(self, model_filename: str):
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
@ -207,7 +206,7 @@ class OnnxModel:
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
model_input = {"x": x.numpy()}
model_output = ["model_out"]
model_output = ["log_probs"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
@ -262,18 +261,18 @@ class OnnxModel:
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
Return a 3-D tensor containing log_probs. Its shape is (N, T', vocab_size)
where T' is usually equal to ((T-7)//2 - 3)//2
"""
model_input, model_output_names = self._build_model_input_output(x)
out = self.encoder.run(model_output_names, model_input)
out = self.model.run(model_output_names, model_input)
self._update_states(out[1:])
@ -323,51 +322,24 @@ def create_streaming_feature_extractor() -> OnlineFeature:
def greedy_search(
model: OnnxModel,
model_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
log_probs: torch.Tensor,
) -> List[int]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
"""Greedy search for a single utterance.
Args:
model:
The transducer model.
model_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
log_probs:
A 3-D tensor of shape (1, T, vocab_size)
Returns:
Return the decoded results so far.
Return the decoded result.
"""
assert log_probs.ndim == 3, log_probs.shape
assert log_probs.shape[0] == 1, log_probs.shape
max_indexes = log_probs[0].argmax(dim=1)
unique_indexes = torch.unique_consecutive(max_indexes)
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
model_out = model_out.squeeze(0)
T = model_out.size(0)
for t in range(T):
cur_encoder_out = model_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
unique_indexes = unique_indexes[unique_indexes != blank_id]
return unique_indexes.tolist()
@torch.no_grad()
@ -376,11 +348,7 @@ def main():
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
model = OnnxModel(model_filename=args.model_filename)
sample_rate = 16000
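For reference, a sketch of the argument parser that the new OnnxModel call relies on; the parser itself lives outside this hunk. The option names come from the usage string shown earlier, while types and structure are assumptions:

import argparse

def get_parser():
    parser = argparse.ArgumentParser()
    # Exported streaming CTC model, e.g. ctc-epoch-20-avg-1-chunk-16-left-128.onnx
    parser.add_argument("--model-filename", type=str, required=True)
    # tokens.txt from the lang dir, e.g. data/lang_bpe_2000/tokens.txt
    parser.add_argument("--tokens", type=str, required=True)
    # The sound file to decode (16 kHz mono)
    parser.add_argument("sound_file", type=str)
    return parser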
@ -400,9 +368,7 @@ def main():
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
hyp = []
chunk = int(1 * sample_rate) # 1 second
start = 0
@ -423,18 +389,28 @@ def main():
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
model_out = model.run_encoder(frames)
hyp = greedy_search(
model,
model_out,
hyp,
)
log_probs = model(frames)
token_table = k2.SymbolTable.from_file(args.tokens)
hyp += greedy_search(log_probs)
text = ""
for i in hyp[context_size:]:
text += token_table[i]
# To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes
id2token = {}
with open(args.tokens, encoding="utf-8") as f:
for line in f:
token, idx = line.split()
if token[:3] == "<0x" and token[-1] == ">":
token = int(token[1:-1], base=16)
assert 0 <= token < 256, token
token = token.to_bytes(1, byteorder="little")
else:
token = token.encode(encoding="utf-8")
id2token[int(idx)] = token
text = b""
for i in hyp:
text += id2token[i]
text = text.decode(encoding="utf-8")
text = text.replace("▁", " ").strip()
logging.info(args.sound_file)
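A small self-contained illustration of the byte-level token handling above, using a hypothetical in-memory token list instead of args.tokens (the token strings and ids are invented for the example):

entries = ["▁hello 10", "<0xE4> 11", "<0xBD> 12", "<0xA0> 13"]
id2token = {}
for line in entries:
    token, idx = line.split()
    if token[:3] == "<0x" and token[-1] == ">":
        # "<0xE4>" -> the single byte 0xE4
        token = int(token[1:-1], base=16).to_bytes(1, byteorder="little")
    else:
        token = token.encode("utf-8")
    id2token[int(idx)] = token

# 0xE4 0xBD 0xA0 are the UTF-8 bytes of "你" (U+4F60)
text = b"".join(id2token[i] for i in [10, 11, 12, 13]).decode("utf-8")
print(text.replace("▁", " ").strip())  # -> "hello你"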

View File

@ -326,7 +326,7 @@ class OnnxModel:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
T' is usually equal to ((T-7)//2-3)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
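As a quick sanity check of the corrected formula, a tiny worked example (the input length 45 is only an illustrative value, not taken from this diff):

T = 45
T_prime = ((T - 7) // 2 - 3) // 2
print(T_prime)  # 8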