Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-03 22:24:19 +00:00
Add greedy search for streaming zipformer CTC.
Parent: 29130fbf27
Commit: ac1e99f6c1
.github/scripts/multi-zh-hans.sh (vendored, 20 changed lines)
@@ -50,6 +50,26 @@ log "----------------------------------------"
 
 ls -lh $repo/exp/
 
+log "------------------------------------------------------------"
+log "Test exported streaming ONNX CTC models (greedy search)"
+log "------------------------------------------------------------"
+
+test_wavs=(
+DEV_T0000000000.wav
+DEV_T0000000001.wav
+DEV_T0000000002.wav
+TEST_MEETING_T0000000113.wav
+TEST_MEETING_T0000000219.wav
+TEST_MEETING_T0000000351.wav
+)
+
+for w in ${test_wavs[@]}; do
+  ./zipformer/onnx_pretrained-streaming-ctc.py \
+    --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+    --tokens $repo/data/lang_bpe_2000/tokens.txt \
+    $repo/test_wavs/$w
+done
+
 log "Upload onnx CTC models to huggingface"
 url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
 GIT_LFS_SKIP_SMUDGE=1 git clone $url
.github/workflows/multi-zh-hans.yml (vendored, 1 changed line)
@@ -4,7 +4,6 @@ on:
   push:
     branches:
       - master
-      - streaming-ctc-decoding
 
   workflow_dispatch:
 
zipformer/onnx_pretrained-streaming-ctc.py

@@ -49,7 +49,7 @@ You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
 
 ./zipformer/onnx_pretrained-streaming-ctc.py \
   --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
-  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --tokens $repo/data/lang_bpe_2000/tokens.txt \
   $repo/test_wavs/DEV_T0000000001.wav
 
 Note: Even though this script only supports decoding a single file,
@@ -58,9 +58,8 @@ the exported ONNX models do support batch processing.
 
 import argparse
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
-import k2
 import numpy as np
 import onnxruntime as ort
 import torch
@@ -111,7 +110,7 @@ class OnnxModel:
 
         self.init_model(model_filename)
 
-    def init_model(self, encoder_model_filename: str):
+    def init_model(self, model_filename: str):
         self.model = ort.InferenceSession(
             model_filename,
             sess_options=self.session_opts,
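Aside: since the CTC export is a single ONNX file (no separate encoder/decoder/joiner models), one quick sanity check is to open it with onnxruntime and list its input/output names. A minimal sketch, not part of the commit; "ctc.onnx" is a placeholder path:

# Minimal sketch (not from the commit): inspect the exported single-file CTC model.
import onnxruntime as ort

opts = ort.SessionOptions()
opts.intra_op_num_threads = 1

sess = ort.InferenceSession(
    "ctc.onnx",  # placeholder path to the exported model
    sess_options=opts,
    providers=["CPUExecutionProvider"],
)
print([i.name for i in sess.get_inputs()])   # expect "x" plus the cached streaming states
print([o.name for o in sess.get_outputs()])  # expect "log_probs" plus the updated states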
@@ -207,7 +206,7 @@ class OnnxModel:
         x: torch.Tensor,
     ) -> Tuple[Dict[str, np.ndarray], List[str]]:
         model_input = {"x": x.numpy()}
-        model_output = ["model_out"]
+        model_output = ["log_probs"]
 
         def build_inputs_outputs(tensors, i):
             assert len(tensors) == 6, len(tensors)
@@ -262,18 +261,18 @@ class OnnxModel:
     def _update_states(self, states: List[np.ndarray]):
         self.states = states
 
-    def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
           x:
             A 3-D tensor of shape (N, T, C)
         Returns:
-          Return a 3-D tensor of shape (N, T', joiner_dim) where
-          T' is usually equal to ((T-7)//2+1)//2
+          Return a 3-D tensor containing log_probs. Its shape is (N, T', vocab_size),
+          where T' is usually equal to ((T-7)//2 - 3)//2
         """
         model_input, model_output_names = self._build_model_input_output(x)
 
-        out = self.encoder.run(model_output_names, model_input)
+        out = self.model.run(model_output_names, model_input)
 
         self._update_states(out[1:])
 
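A quick arithmetic illustration of the corrected output-length formula (the numbers are illustrative, not from the commit):

# With T = 45 input feature frames:
#   old formula: ((45 - 7) // 2 + 1) // 2 = 20 // 2 = 10
#   new formula: ((45 - 7) // 2 - 3) // 2 = 16 // 2 = 8 frames of log_probs
T = 45
T_out = ((T - 7) // 2 - 3) // 2
print(T_out)  # 8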
@@ -323,51 +322,24 @@ def create_streaming_feature_extractor() -> OnlineFeature:
 
 
 def greedy_search(
-    model: OnnxModel,
-    model_out: torch.Tensor,
-    context_size: int,
-    decoder_out: Optional[torch.Tensor] = None,
-    hyp: Optional[List[int]] = None,
+    log_probs: torch.Tensor,
 ) -> List[int]:
-    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    """Greedy search for a single utterance.
     Args:
-      model:
-        The transducer model.
-      model_out:
-        A 3-D tensor of shape (1, T, joiner_dim)
-      context_size:
-        The context size of the decoder model.
-      decoder_out:
-        Optional. Decoder output of the previous chunk.
-      hyp:
-        Decoding results for previous chunks.
+      log_probs:
+        A 3-D tensor of shape (1, T, vocab_size)
     Returns:
-      Return the decoded results so far.
+      Return the decoded result.
     """
+    assert log_probs.ndim == 3, log_probs.shape
+    assert log_probs.shape[0] == 1, log_probs.shape
+
+    max_indexes = log_probs[0].argmax(dim=1)
+    unique_indexes = torch.unique_consecutive(max_indexes)
 
     blank_id = 0
 
-    if decoder_out is None:
-        assert hyp is None, hyp
-        hyp = [blank_id] * context_size
-        decoder_input = torch.tensor([hyp], dtype=torch.int64)
-        decoder_out = model.run_decoder(decoder_input)
-    else:
-        assert hyp is not None, hyp
-
-    model_out = model_out.squeeze(0)
-    T = model_out.size(0)
-    for t in range(T):
-        cur_encoder_out = model_out[t : t + 1]
-        joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
-        y = joiner_out.argmax(dim=0).item()
-        if y != blank_id:
-            hyp.append(y)
-            decoder_input = hyp[-context_size:]
-            decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
-            decoder_out = model.run_decoder(decoder_input)
-
-    return hyp, decoder_out
+    unique_indexes = unique_indexes[unique_indexes != blank_id]
+    return unique_indexes.tolist()
 
 
 @torch.no_grad()
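The new greedy_search is standard CTC greedy decoding: pick the argmax token per frame, collapse consecutive repeats with torch.unique_consecutive, and drop blanks. A self-contained toy check of that logic (the input tensor below is fabricated for illustration):

# Toy check of the CTC greedy search above; the input is made up.
import torch

def greedy_search(log_probs: torch.Tensor):
    assert log_probs.ndim == 3, log_probs.shape
    assert log_probs.shape[0] == 1, log_probs.shape
    max_indexes = log_probs[0].argmax(dim=1)                 # best token per frame
    unique_indexes = torch.unique_consecutive(max_indexes)   # collapse repeats
    blank_id = 0
    return unique_indexes[unique_indexes != blank_id].tolist()  # drop blanks

# Build (1, 8, 10) log-probs whose per-frame argmax is [0, 5, 5, 0, 7, 7, 7, 0]:
ids = torch.tensor([0, 5, 5, 0, 7, 7, 7, 0])
log_probs = torch.full((1, 8, 10), -10.0)
log_probs[0, torch.arange(8), ids] = 0.0
print(greedy_search(log_probs))  # [5, 7]

Note that repeats are collapsed only within a chunk; main() simply concatenates the per-chunk results with hyp += greedy_search(log_probs).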
@@ -376,11 +348,7 @@ def main():
     args = parser.parse_args()
     logging.info(vars(args))
 
-    model = OnnxModel(
-        encoder_model_filename=args.encoder_model_filename,
-        decoder_model_filename=args.decoder_model_filename,
-        joiner_model_filename=args.joiner_model_filename,
-    )
+    model = OnnxModel(model_filename=args.model_filename)
 
     sample_rate = 16000
 
@@ -400,9 +368,7 @@ def main():
     segment = model.segment
     offset = model.offset
 
-    context_size = model.context_size
-    hyp = None
-    decoder_out = None
+    hyp = []
 
     chunk = int(1 * sample_rate)  # 1 second
     start = 0
@@ -423,18 +389,28 @@ def main():
             num_processed_frames += offset
             frames = torch.cat(frames, dim=0)
             frames = frames.unsqueeze(0)
-            model_out = model.run_encoder(frames)
-            hyp = greedy_search(
-                model,
-                model_out,
-                hyp,
-            )
+            log_probs = model(frames)
+
+            hyp += greedy_search(log_probs)
 
-    token_table = k2.SymbolTable.from_file(args.tokens)
-
-    text = ""
-    for i in hyp[context_size:]:
-        text += token_table[i]
+    # To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes
+    id2token = {}
+    with open(args.tokens, encoding="utf-8") as f:
+        for line in f:
+            token, idx = line.split()
+            if token[:3] == "<0x" and token[-1] == ">":
+                token = int(token[1:-1], base=16)
+                assert 0 <= token < 256, token
+                token = token.to_bytes(1, byteorder="little")
+            else:
+                token = token.encode(encoding="utf-8")
+
+            id2token[int(idx)] = token
+
+    text = b""
+    for i in hyp:
+        text += id2token[i]
+    text = text.decode(encoding="utf-8")
+    text = text.replace("▁", " ").strip()
 
     logging.info(args.sound_file)
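The token-table handling above is what makes byte-level BPE work: entries such as <0xE4> stand for single raw bytes, so every token ID is first mapped to bytes and only the final concatenation is decoded as UTF-8. A toy illustration with a made-up table (the entries and IDs below are hypothetical):

# Hypothetical tokens.txt entries; <0xE4><0xB8><0x96> is the UTF-8 encoding of 世.
entries = ["▁你好 3", "<0xE4> 4", "<0xB8> 5", "<0x96> 6", "▁world 7"]

id2token = {}
for line in entries:
    token, idx = line.split()
    if token[:3] == "<0x" and token[-1] == ">":
        token = int(token[1:-1], base=16).to_bytes(1, byteorder="little")
    else:
        token = token.encode("utf-8")
    id2token[int(idx)] = token

hyp = [3, 4, 5, 6, 7]
text = b"".join(id2token[i] for i in hyp).decode("utf-8")
print(text.replace("▁", " ").strip())  # 你好世 world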
The transducer streaming script gets the matching docstring fix for run_encoder:

@@ -326,7 +326,7 @@ class OnnxModel:
           A 3-D tensor of shape (N, T, C)
         Returns:
           Return a 3-D tensor of shape (N, T', joiner_dim) where
-          T' is usually equal to ((T-7)//2+1)//2
+          T' is usually equal to ((T-7)//2-3)//2
         """
         encoder_input, encoder_output_names = self._build_encoder_input_output(x)
 