mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00

use token_table instead

This commit is contained in:
  parent 00623c45fd
  commit a53d7102da
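In short: the ONNX decoding script no longer needs sentencepiece. Instead of turning hypothesis token IDs into text with a BPE model (sp.decode), it looks each ID up in a k2.SymbolTable built from tokens.txt, concatenates the pieces, and replaces the SentencePiece word-boundary marker "▁" with a space. A minimal sketch of the idea, not part of the commit (the path and IDs below are illustrative):

    from k2 import SymbolTable

    token_table = SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

    # Before: text = sp.decode(token_ids)
    # After: map each ID through the table and restore word boundaries.
    token_ids = [23, 57, 8]  # hypothetical IDs
    text = "".join(token_table[i] for i in token_ids).replace("▁", " ").strip()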
@@ -64,6 +64,7 @@ It will generate the following 3 files inside $repo/exp:
   --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
   --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
   --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
 """


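For reference, tokens.txt maps one token per line to its integer ID. The exact pieces depend on the trained BPE model, so the entries below are only illustrative, though <blk> at ID 0 matches the convention the old code asserted:

    <blk> 0
    <sos/eos> 1
    <unk> 2
    ▁THE 3
    ...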
@@ -73,7 +74,6 @@ import time
 from pathlib import Path
 from typing import List, Tuple

-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -81,6 +81,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from onnx_pretrained import greedy_search, OnnxModel

 from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable


 def get_parser():
@@ -117,10 +118,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -134,7 +134,7 @@ def get_parser():


 def decode_one_batch(
-    model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
+    model: OnnxModel, token_table: SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.
     Currently it only greedy_search is supported.
@@ -142,8 +142,8 @@ def decode_one_batch(
     Args:
       model:
         The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -165,14 +165,20 @@ def decode_one_batch(
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )

-    hyps = [sp.decode(h).split() for h in hyps]
+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
+    hyps = [token_ids_to_words(h).split() for h in hyps]
     return hyps


 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    token_table: SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.

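To see what the new token_ids_to_words helper does, here is a self-contained toy run (the table contents are invented; a real tokens.txt has one entry per BPE piece):

    from k2 import SymbolTable

    token_table = SymbolTable.from_str("""
    <blk> 0
    ▁HELLO 1
    ▁WOR 2
    LD 3
    """)

    text = ""
    for i in [1, 2, 3]:  # hypothetical greedy-search output
        text += token_table[i]
    print(text.replace("▁", " ").strip())  # -> HELLO WORLD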
||||
@@ -181,8 +187,8 @@ def decode_dataset(
        PyTorch's dataloader containing the dataset to decode.
      model:
        The neural model.
-     sp:
-       The BPE model.
+     token_table:
+       The token table.

     Returns:
       - A list of tuples. Each tuple contains three elements:
@@ -207,7 +213,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])

-        hyps = decode_one_batch(model=model, sp=sp, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)

         this_batch = []
         assert len(hyps) == len(texts)
@@ -271,11 +277,7 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    blank_id = sp.piece_to_id("<blk>")
-    assert blank_id == 0, blank_id
+    token_table = SymbolTable.from_file(args.tokens)

     logging.info(vars(args))

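Note that the commit also drops the old sanity check that the blank token has ID 0. If one wanted to keep an equivalent check against the token table, something like this sketch would do (it assumes tokens.txt defines a "<blk>" entry, as icefall's BPE lang dirs conventionally do):

    token_table = SymbolTable.from_file(args.tokens)
    blank_id = token_table["<blk>"]  # a SymbolTable maps symbol -> ID as well as ID -> symbol
    assert blank_id == 0, blank_id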
@@ -301,7 +303,7 @@ def main():

     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
-        results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
+        results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
         end_time = time.time()
         elapsed_seconds = end_time - start_time
         rtf = elapsed_seconds / total_duration