use token_table instead

yaozengwei 2023-06-30 11:38:26 +08:00
parent 00623c45fd
commit a53d7102da


@@ -64,6 +64,7 @@ It will generate the following 3 files inside $repo/exp:
   --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
   --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
   --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
 """
@@ -73,7 +74,6 @@ import time
 from pathlib import Path
 from typing import List, Tuple

-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -81,6 +81,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from onnx_pretrained import greedy_search, OnnxModel

 from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable


 def get_parser():
@@ -117,10 +118,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -134,7 +134,7 @@ def get_parser():

 def decode_one_batch(
-    model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
+    model: OnnxModel, token_table: SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.

     Currently it only greedy_search is supported.
@@ -142,8 +142,8 @@ def decode_one_batch(
     Args:
       model:
         The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -165,14 +165,20 @@ def decode_one_batch(
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )

-    hyps = [sp.decode(h).split() for h in hyps]
+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
+    hyps = [token_ids_to_words(h).split() for h in hyps]

     return hyps


 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    token_table: SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.
@@ -181,8 +187,8 @@ def decode_dataset(
       PyTorch's dataloader containing the dataset to decode.
       model:
         The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.

     Returns:
       - A list of tuples. Each tuple contains three elements:
@@ -207,7 +213,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])

-        hyps = decode_one_batch(model=model, sp=sp, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)

         this_batch = []
         assert len(hyps) == len(texts)
@@ -271,11 +277,7 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    blank_id = sp.piece_to_id("<blk>")
-    assert blank_id == 0, blank_id
+    token_table = SymbolTable.from_file(args.tokens)

     logging.info(vars(args))
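
The removed SentencePiece code asserted that <blk> maps to ID 0; the SymbolTable replacement drops that check. A sketch of an equivalent check (not in the commit; k2's SymbolTable also accepts a symbol string and returns its ID):

    token_table = SymbolTable.from_file(args.tokens)
    assert token_table["<blk>"] == 0, token_table["<blk>"]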
@@ -301,7 +303,7 @@ def main():
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
-        results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
+        results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
         end_time = time.time()
         elapsed_seconds = end_time - start_time
         rtf = elapsed_seconds / total_duration
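
For reference, the rtf computed in the last context line is the real-time factor: wall-clock decoding time divided by the duration of the audio decoded. A minimal worked example with made-up numbers:

    elapsed_seconds = 120.0  # wall-clock decoding time (made up)
    total_duration = 2400.0  # seconds of audio decoded (made up)
    rtf = elapsed_seconds / total_duration  # 0.05 -> 20x faster than real time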