From a53d7102da19d00d433e8277d1bd3448d8c2b805 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Fri, 30 Jun 2023 11:38:26 +0800
Subject: [PATCH] use token_table instead

---
 egs/librispeech/ASR/zipformer/onnx_decode.py | 38 ++++++++++----------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py
index 2e7fdca9d..2aca36ca9 100755
--- a/egs/librispeech/ASR/zipformer/onnx_decode.py
+++ b/egs/librispeech/ASR/zipformer/onnx_decode.py
@@ -64,6 +64,7 @@ It will generate the following 3 files inside $repo/exp:
   --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
   --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
   --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
 """
 
 
@@ -73,7 +74,6 @@ import time
 from pathlib import Path
 from typing import List, Tuple
 
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -81,6 +81,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from onnx_pretrained import greedy_search, OnnxModel
 
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable
 
 
 def get_parser():
@@ -117,10 +118,9 @@
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="""Path to tokens.txt.""",
     )
 
     parser.add_argument(
@@ -134,7 +134,7 @@
 
 
 def decode_one_batch(
-    model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
+    model: OnnxModel, token_table: SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.
     Currently it only greedy_search is supported.
@@ -142,8 +142,8 @@
     Args:
       model:
        The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
       batch:
         It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -165,14 +165,20 @@
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )
 
-    hyps = [sp.decode(h).split() for h in hyps]
+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
+    hyps = [token_ids_to_words(h).split() for h in hyps]
     return hyps
 
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    token_table: SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.
 
@@ -181,8 +187,8 @@
         PyTorch's dataloader containing the dataset to decode.
       model:
         The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
 
     Returns:
       - A list of tuples. Each tuple contains three elements:
@@ -207,7 +213,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
 
-        hyps = decode_one_batch(model=model, sp=sp, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
 
         this_batch = []
         assert len(hyps) == len(texts)
@@ -271,11 +277,7 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    blank_id = sp.piece_to_id("<blk>")
-    assert blank_id == 0, blank_id
+    token_table = SymbolTable.from_file(args.tokens)
 
     logging.info(vars(args))
 
@@ -301,7 +303,7 @@ def main():
 
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
-        results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
+        results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
         end_time = time.time()
         elapsed_seconds = end_time - start_time
         rtf = elapsed_seconds / total_duration
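
Note (reviewer sketch, not part of the patch): the replacement decoding path can be tried in
isolation. A minimal sketch, assuming a tokens.txt in k2's SymbolTable format (one
"<symbol> <id>" pair per line), such as the data/lang_bpe_500/tokens.txt referenced in the
usage example at the top of the patched file:

    # Sketch of the token-table decoding this patch introduces; the path
    # below is illustrative and must point at your own tokens.txt.
    from typing import List

    from k2 import SymbolTable

    token_table = SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

    def token_ids_to_words(token_ids: List[int]) -> str:
        # Join the BPE pieces of one hypothesis, then turn the SentencePiece
        # word-boundary marker "▁" back into spaces.
        text = "".join(token_table[i] for i in token_ids)
        return text.replace("▁", " ").strip()

    # e.g. for token IDs produced by greedy_search for one utterance:
    # words = token_ids_to_words(hyp).split()

Compared with sp.decode(), this needs only tokens.txt at decode time, so the sentencepiece
dependency and the --bpe-model argument (with its bpe.model file) can be dropped.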