mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Add greedy search for streaming zipformer CTC.
This commit is contained in:
parent
29130fbf27
commit
ac1e99f6c1
20
.github/scripts/multi-zh-hans.sh
vendored
20
.github/scripts/multi-zh-hans.sh
vendored
@ -50,6 +50,26 @@ log "----------------------------------------"
|
|||||||
|
|
||||||
ls -lh $repo/exp/
|
ls -lh $repo/exp/
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Test exported streaming ONNX CTC models (greedy search) "
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
test_wavs=(
|
||||||
|
DEV_T0000000000.wav
|
||||||
|
DEV_T0000000001.wav
|
||||||
|
DEV_T0000000002.wav
|
||||||
|
TEST_MEETING_T0000000113.wav
|
||||||
|
TEST_MEETING_T0000000219.wav
|
||||||
|
TEST_MEETING_T0000000351.wav
|
||||||
|
)
|
||||||
|
|
||||||
|
for w in ${test_wavs[@]}; do
|
||||||
|
./zipformer/onnx_pretrained-streaming-ctc.py \
|
||||||
|
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/$w
|
||||||
|
done
|
||||||
|
|
||||||
log "Upload onnx CTC models to huggingface"
|
log "Upload onnx CTC models to huggingface"
|
||||||
url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
|
url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $url
|
GIT_LFS_SKIP_SMUDGE=1 git clone $url
|
||||||
|
1
.github/workflows/multi-zh-hans.yml
vendored
1
.github/workflows/multi-zh-hans.yml
vendored
@ -4,7 +4,6 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
- streaming-ctc-decoding
|
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
|
|||||||
|
|
||||||
./zipformer/onnx_pretrained-streaming-ctc.py \
|
./zipformer/onnx_pretrained-streaming-ctc.py \
|
||||||
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
|
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav
|
$repo/test_wavs/DEV_T0000000001.wav
|
||||||
|
|
||||||
Note: Even though this script only supports decoding a single file,
|
Note: Even though this script only supports decoding a single file,
|
||||||
@ -58,9 +58,8 @@ the exported ONNX models do support batch processing.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
@ -111,7 +110,7 @@ class OnnxModel:
|
|||||||
|
|
||||||
self.init_model(model_filename)
|
self.init_model(model_filename)
|
||||||
|
|
||||||
def init_model(self, encoder_model_filename: str):
|
def init_model(self, model_filename: str):
|
||||||
self.model = ort.InferenceSession(
|
self.model = ort.InferenceSession(
|
||||||
model_filename,
|
model_filename,
|
||||||
sess_options=self.session_opts,
|
sess_options=self.session_opts,
|
||||||
@ -207,7 +206,7 @@ class OnnxModel:
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||||
model_input = {"x": x.numpy()}
|
model_input = {"x": x.numpy()}
|
||||||
model_output = ["model_out"]
|
model_output = ["log_probs"]
|
||||||
|
|
||||||
def build_inputs_outputs(tensors, i):
|
def build_inputs_outputs(tensors, i):
|
||||||
assert len(tensors) == 6, len(tensors)
|
assert len(tensors) == 6, len(tensors)
|
||||||
@ -262,18 +261,18 @@ class OnnxModel:
|
|||||||
def _update_states(self, states: List[np.ndarray]):
|
def _update_states(self, states: List[np.ndarray]):
|
||||||
self.states = states
|
self.states = states
|
||||||
|
|
||||||
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
A 3-D tensor of shape (N, T, C)
|
A 3-D tensor of shape (N, T, C)
|
||||||
Returns:
|
Returns:
|
||||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
Return a 3-D tensor containing log_probs. Its shape is (N, T', vocab_size)
|
||||||
T' is usually equal to ((T-7)//2+1)//2
|
where T' is usually equal to ((T-7)//2 - 3)//2
|
||||||
"""
|
"""
|
||||||
model_input, model_output_names = self._build_model_input_output(x)
|
model_input, model_output_names = self._build_model_input_output(x)
|
||||||
|
|
||||||
out = self.encoder.run(model_output_names, model_input)
|
out = self.model.run(model_output_names, model_input)
|
||||||
|
|
||||||
self._update_states(out[1:])
|
self._update_states(out[1:])
|
||||||
|
|
||||||
@ -323,51 +322,24 @@ def create_streaming_feature_extractor() -> OnlineFeature:
|
|||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
model: OnnxModel,
|
log_probs: torch.Tensor,
|
||||||
model_out: torch.Tensor,
|
|
||||||
context_size: int,
|
|
||||||
decoder_out: Optional[torch.Tensor] = None,
|
|
||||||
hyp: Optional[List[int]] = None,
|
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search for a single utterance.
|
||||||
Args:
|
Args:
|
||||||
model:
|
log_probs:
|
||||||
The transducer model.
|
A 3-D tensor of shape (1, T, vocab_size)
|
||||||
model_out:
|
|
||||||
A 3-D tensor of shape (1, T, joiner_dim)
|
|
||||||
context_size:
|
|
||||||
The context size of the decoder model.
|
|
||||||
decoder_out:
|
|
||||||
Optional. Decoder output of the previous chunk.
|
|
||||||
hyp:
|
|
||||||
Decoding results for previous chunks.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded results so far.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
|
assert log_probs.ndim == 3, log_probs.shape
|
||||||
|
assert log_probs.shape[0] == 1, log_probs.shape
|
||||||
|
|
||||||
|
max_indexes = log_probs[0].argmax(dim=1)
|
||||||
|
unique_indexes = torch.unique_consecutive(max_indexes)
|
||||||
|
|
||||||
blank_id = 0
|
blank_id = 0
|
||||||
|
unique_indexes = unique_indexes[unique_indexes != blank_id]
|
||||||
if decoder_out is None:
|
return unique_indexes.tolist()
|
||||||
assert hyp is None, hyp
|
|
||||||
hyp = [blank_id] * context_size
|
|
||||||
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
|
||||||
decoder_out = model.run_decoder(decoder_input)
|
|
||||||
else:
|
|
||||||
assert hyp is not None, hyp
|
|
||||||
|
|
||||||
model_out = model_out.squeeze(0)
|
|
||||||
T = model_out.size(0)
|
|
||||||
for t in range(T):
|
|
||||||
cur_encoder_out = model_out[t : t + 1]
|
|
||||||
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
|
||||||
y = joiner_out.argmax(dim=0).item()
|
|
||||||
if y != blank_id:
|
|
||||||
hyp.append(y)
|
|
||||||
decoder_input = hyp[-context_size:]
|
|
||||||
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
|
||||||
decoder_out = model.run_decoder(decoder_input)
|
|
||||||
|
|
||||||
return hyp, decoder_out
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -376,11 +348,7 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logging.info(vars(args))
|
logging.info(vars(args))
|
||||||
|
|
||||||
model = OnnxModel(
|
model = OnnxModel(model_filename=args.model_filename)
|
||||||
encoder_model_filename=args.encoder_model_filename,
|
|
||||||
decoder_model_filename=args.decoder_model_filename,
|
|
||||||
joiner_model_filename=args.joiner_model_filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
sample_rate = 16000
|
sample_rate = 16000
|
||||||
|
|
||||||
@ -400,9 +368,7 @@ def main():
|
|||||||
segment = model.segment
|
segment = model.segment
|
||||||
offset = model.offset
|
offset = model.offset
|
||||||
|
|
||||||
context_size = model.context_size
|
hyp = []
|
||||||
hyp = None
|
|
||||||
decoder_out = None
|
|
||||||
|
|
||||||
chunk = int(1 * sample_rate) # 1 second
|
chunk = int(1 * sample_rate) # 1 second
|
||||||
start = 0
|
start = 0
|
||||||
@ -423,18 +389,28 @@ def main():
|
|||||||
num_processed_frames += offset
|
num_processed_frames += offset
|
||||||
frames = torch.cat(frames, dim=0)
|
frames = torch.cat(frames, dim=0)
|
||||||
frames = frames.unsqueeze(0)
|
frames = frames.unsqueeze(0)
|
||||||
model_out = model.run_encoder(frames)
|
log_probs = model(frames)
|
||||||
hyp = greedy_search(
|
|
||||||
model,
|
|
||||||
model_out,
|
|
||||||
hyp,
|
|
||||||
)
|
|
||||||
|
|
||||||
token_table = k2.SymbolTable.from_file(args.tokens)
|
hyp += greedy_search(log_probs)
|
||||||
|
|
||||||
text = ""
|
# To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes
|
||||||
for i in hyp[context_size:]:
|
id2token = {}
|
||||||
text += token_table[i]
|
with open(args.tokens, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
token, idx = line.split()
|
||||||
|
if token[:3] == "<0x" and token[-1] == ">":
|
||||||
|
token = int(token[1:-1], base=16)
|
||||||
|
assert 0 <= token < 256, token
|
||||||
|
token = token.to_bytes(1, byteorder="little")
|
||||||
|
else:
|
||||||
|
token = token.encode(encoding="utf-8")
|
||||||
|
|
||||||
|
id2token[int(idx)] = token
|
||||||
|
|
||||||
|
text = b""
|
||||||
|
for i in hyp:
|
||||||
|
text += id2token[i]
|
||||||
|
text = text.decode(encoding="utf-8")
|
||||||
text = text.replace("▁", " ").strip()
|
text = text.replace("▁", " ").strip()
|
||||||
|
|
||||||
logging.info(args.sound_file)
|
logging.info(args.sound_file)
|
||||||
|
@ -326,7 +326,7 @@ class OnnxModel:
|
|||||||
A 3-D tensor of shape (N, T, C)
|
A 3-D tensor of shape (N, T, C)
|
||||||
Returns:
|
Returns:
|
||||||
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||||
T' is usually equal to ((T-7)//2+1)//2
|
T' is usually equal to ((T-7)//2-3)//2
|
||||||
"""
|
"""
|
||||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user