From ac1e99f6c141344ba213d6384c8743898ca665c9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 13 Dec 2023 17:15:09 +0800 Subject: [PATCH] Add greedy search for streaming zipformer CTC. --- .github/scripts/multi-zh-hans.sh | 20 ++++ .github/workflows/multi-zh-hans.yml | 1 - .../onnx_pretrained-streaming-ctc.py | 108 +++++++----------- .../zipformer/onnx_pretrained-streaming.py | 2 +- 4 files changed, 63 insertions(+), 68 deletions(-) diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index ab0c4471b..427d8887b 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -50,6 +50,26 @@ log "----------------------------------------" ls -lh $repo/exp/ +log "------------------------------------------------------------" +log "Test exported streaming ONNX CTC models (greedy search) " +log "------------------------------------------------------------" + +test_wavs=( +DEV_T0000000000.wav +DEV_T0000000001.wav +DEV_T0000000002.wav +TEST_MEETING_T0000000113.wav +TEST_MEETING_T0000000219.wav +TEST_MEETING_T0000000351.wav +) + +for w in ${test_wavs[@]}; do + ./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/$w +done + log "Upload onnx CTC models to huggingface" url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 GIT_LFS_SKIP_SMUDGE=1 git clone $url diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml index f0222869d..9081047de 100644 --- a/.github/workflows/multi-zh-hans.yml +++ b/.github/workflows/multi-zh-hans.yml @@ -4,7 +4,6 @@ on: push: branches: - master - - streaming-ctc-decoding workflow_dispatch: diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py index c1d8195ad..e07a64d1b 100755 --- 
a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -49,7 +49,7 @@ You can use either the ``int8.onnx`` model or just the ``.onnx`` model. ./zipformer/onnx_pretrained-streaming-ctc.py \ --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ $repo/test_wavs/DEV_T0000000001.wav Note: Even though this script only supports decoding a single file, @@ -58,9 +58,8 @@ the exported ONNX models do support batch processing. import argparse import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple -import k2 import numpy as np import onnxruntime as ort import torch @@ -111,7 +110,7 @@ class OnnxModel: self.init_model(model_filename) - def init_model(self, encoder_model_filename: str): + def init_model(self, model_filename: str): self.model = ort.InferenceSession( model_filename, sess_options=self.session_opts, @@ -207,7 +206,7 @@ class OnnxModel: x: torch.Tensor, ) -> Tuple[Dict[str, np.ndarray], List[str]]: model_input = {"x": x.numpy()} - model_output = ["model_out"] + model_output = ["log_probs"] def build_inputs_outputs(tensors, i): assert len(tensors) == 6, len(tensors) @@ -262,18 +261,18 @@ class OnnxModel: def _update_states(self, states: List[np.ndarray]): self.states = states - def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: A 3-D tensor of shape (N, T, C) Returns: - Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 + Return a 3-D tensor containing log_probs. 
Its shape is (N, T', vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 """ model_input, model_output_names = self._build_model_input_output(x) - out = self.encoder.run(model_output_names, model_input) + out = self.model.run(model_output_names, model_input) self._update_states(out[1:]) @@ -323,51 +322,24 @@ def create_streaming_feature_extractor() -> OnlineFeature: def greedy_search( - model: OnnxModel, - model_out: torch.Tensor, - context_size: int, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, + log_probs: torch.Tensor, ) -> List[int]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + """Greedy search for a single utterance. Args: - model: - The transducer model. - model_out: - A 3-D tensor of shape (1, T, joiner_dim) - context_size: - The context size of the decoder model. - decoder_out: - Optional. Decoder output of the previous chunk. - hyp: - Decoding results for previous chunks. + log_probs: + A 3-D tensor of shape (1, T, vocab_size) Returns: - Return the decoded results so far. + Return the decoded result. 
""" + assert log_probs.ndim == 3, log_probs.shape + assert log_probs.shape[0] == 1, log_probs.shape + + max_indexes = log_probs[0].argmax(dim=1) + unique_indexes = torch.unique_consecutive(max_indexes) blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - else: - assert hyp is not None, hyp - - model_out = model_out.squeeze(0) - T = model_out.size(0) - for t in range(T): - cur_encoder_out = model_out[t : t + 1] - joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor([decoder_input], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - - return hyp, decoder_out + unique_indexes = unique_indexes[unique_indexes != blank_id] + return unique_indexes.tolist() @torch.no_grad() @@ -376,11 +348,7 @@ def main(): args = parser.parse_args() logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) + model = OnnxModel(model_filename=args.model_filename) sample_rate = 16000 @@ -400,9 +368,7 @@ def main(): segment = model.segment offset = model.offset - context_size = model.context_size - hyp = None - decoder_out = None + hyp = [] chunk = int(1 * sample_rate) # 1 second start = 0 @@ -423,18 +389,28 @@ def main(): num_processed_frames += offset frames = torch.cat(frames, dim=0) frames = frames.unsqueeze(0) - model_out = model.run_encoder(frames) - hyp = greedy_search( - model, - model_out, - hyp, - ) + log_probs = model(frames) - token_table = k2.SymbolTable.from_file(args.tokens) + hyp += greedy_search(log_probs) - text = "" - for i in hyp[context_size:]: - text += token_table[i] + # To handle byte-level 
BPE, we convert string tokens to utf-8 encoded bytes + id2token = {} + with open(args.tokens, encoding="utf-8") as f: + for line in f: + token, idx = line.split() + if token[:3] == "<0x" and token[-1] == ">": + token = int(token[1:-1], base=16) + assert 0 <= token < 256, token + token = token.to_bytes(1, byteorder="little") + else: + token = token.encode(encoding="utf-8") + + id2token[int(idx)] = token + + text = b"" + for i in hyp: + text += id2token[i] + text = text.decode(encoding="utf-8") text = text.replace("▁", " ").strip() logging.info(args.sound_file) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index e62491444..e7c4f40ee 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -326,7 +326,7 @@ class OnnxModel: A 3-D tensor of shape (N, T, C) Returns: Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 + T' is usually equal to ((T-7)//2-3)//2 """ encoder_input, encoder_output_names = self._build_encoder_input_output(x)