use token_table instead

yaozengwei 2023-06-30 11:38:26 +08:00
parent 00623c45fd
commit a53d7102da


@@ -64,6 +64,7 @@ It will generate the following 3 files inside $repo/exp:
   --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
   --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
   --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
 """
@@ -73,7 +74,6 @@ import time
 from pathlib import Path
 from typing import List, Tuple
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -81,6 +81,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from onnx_pretrained import greedy_search, OnnxModel
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable
 
 
 def get_parser():
@@ -117,10 +118,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="""Path to tokens.txt.""",
     )
 
     parser.add_argument(
@@ -134,7 +134,7 @@
 def decode_one_batch(
-    model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
+    model: OnnxModel, token_table: SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.
     Currently only greedy_search is supported.
@@ -142,8 +142,8 @@ def decode_one_batch(
     Args:
       model:
        The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
       batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -165,14 +165,20 @@ def decode_one_batch(
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )
 
-    hyps = [sp.decode(h).split() for h in hyps]
+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
+    hyps = [token_ids_to_words(h).split() for h in hyps]
 
     return hyps
 
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    token_table: SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.
@@ -181,8 +187,8 @@ def decode_dataset(
        PyTorch's dataloader containing the dataset to decode.
       model:
        The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
 
     Returns:
       - A list of tuples. Each tuple contains three elements:
@@ -207,7 +213,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
 
-        hyps = decode_one_batch(model=model, sp=sp, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
 
         this_batch = []
         assert len(hyps) == len(texts)
@@ -271,11 +277,7 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    blank_id = sp.piece_to_id("<blk>")
-    assert blank_id == 0, blank_id
+    token_table = SymbolTable.from_file(args.tokens)
 
     logging.info(vars(args))
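Note that the old blank-id sanity check is dropped rather than ported. If it were still wanted, here is a minimal sketch (assuming, as in k2, that SymbolTable supports lookup both by symbol and by id, and that tokens.txt assigns <blk> the id 0):

    from k2 import SymbolTable

    token_table = SymbolTable.from_file("data/lang_bpe_500/tokens.txt")
    assert token_table["<blk>"] == 0, token_table["<blk>"]  # id lookup by symbol
    assert token_table[0] == "<blk>"                        # symbol lookup by id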
@@ -301,7 +303,7 @@ def main():
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
-        results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
+        results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
         end_time = time.time()
         elapsed_seconds = end_time - start_time
         rtf = elapsed_seconds / total_duration
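For context on that last line: rtf is the real-time factor, the elapsed wall-clock decoding time divided by the duration of audio decoded, so values below 1.0 mean faster than real time. A tiny arithmetic check with made-up numbers:

    elapsed_seconds = 20.0   # hypothetical wall-clock time spent decoding
    total_duration = 100.0   # hypothetical seconds of audio in the test set
    rtf = elapsed_seconds / total_duration
    print(f"RTF: {rtf:.2f}")  # RTF: 0.20, i.e. 5x faster than real time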