use token_table instead

parent 00623c45fd
commit a53d7102da
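In this commit the ONNX decoding script stops depending on a SentencePiece BPE model: the --bpe-model argument is replaced by --tokens (a tokens.txt file), the sentencepiece import is dropped, and hypothesis token ids are mapped to words through a k2 SymbolTable instead of sp.decode().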
@@ -64,6 +64,7 @@ It will generate the following 3 files inside $repo/exp:
   --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
   --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
   --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
 """


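Note: tokens.txt is a plain-text symbol table with one "symbol id" pair per line. The entries below are only an illustrative sketch of the format, not contents taken from this commit:

    <blk> 0
    <sos/eos> 1
    <unk> 2
    ▁THE 3
    ▁A 4
    ...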
@@ -73,7 +74,6 @@ import time
 from pathlib import Path
 from typing import List, Tuple

-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -81,6 +81,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from onnx_pretrained import greedy_search, OnnxModel

 from icefall.utils import setup_logger, store_transcripts, write_error_stats
+from k2 import SymbolTable


 def get_parser():
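The newly imported k2.SymbolTable is a bidirectional symbol/id mapping: indexing with an integer id returns the symbol string, and indexing with a symbol returns its id. A minimal sketch, assuming the default lang directory from the usage example above:

    from k2 import SymbolTable

    # Load the token table written during lang preparation.
    token_table = SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

    # Round trip: symbol -> id -> symbol.
    assert token_table[token_table["<unk>"]] == "<unk>"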
@@ -117,10 +118,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -134,7 +134,7 @@ def get_parser():


 def decode_one_batch(
-    model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
+    model: OnnxModel, token_table: SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.

     Currently it only greedy_search is supported.
@@ -142,8 +142,8 @@ def decode_one_batch(
     Args:
       model:
         The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -165,14 +165,20 @@ def decode_one_batch(
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )

-    hyps = [sp.decode(h).split() for h in hyps]
+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
+    hyps = [token_ids_to_words(h).split() for h in hyps]
     return hyps


 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    token_table: SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.

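A standalone check of the new token_ids_to_words logic, using a made-up four-entry table (the symbols and the SymbolTable.from_str input here are illustrative only; a real table comes from tokens.txt):

    from typing import List

    from k2 import SymbolTable

    # Tiny made-up token table in "symbol id" format.
    token_table = SymbolTable.from_str("<blk> 0\n▁HELLO 1\n▁WOR 2\nLD 3")

    def token_ids_to_words(token_ids: List[int]) -> str:
        # Concatenate the BPE pieces, then turn the word-boundary
        # marker "▁" (U+2581) into spaces and strip the leading one.
        text = ""
        for i in token_ids:
            text += token_table[i]
        return text.replace("▁", " ").strip()

    assert token_ids_to_words([1, 2, 3]) == "HELLO WORLD"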
@@ -181,8 +187,8 @@ def decode_dataset(
         PyTorch's dataloader containing the dataset to decode.
       model:
         The neural model.
-      sp:
-        The BPE model.
+      token_table:
+        The token table.

     Returns:
       - A list of tuples. Each tuple contains three elements:
@@ -207,7 +213,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])

-        hyps = decode_one_batch(model=model, sp=sp, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)

         this_batch = []
         assert len(hyps) == len(texts)
@@ -271,11 +277,7 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(args.bpe_model)
-
-    blank_id = sp.piece_to_id("<blk>")
-    assert blank_id == 0, blank_id
+    token_table = SymbolTable.from_file(args.tokens)

     logging.info(vars(args))
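Note: the removed SentencePiece path also asserted that the blank token has id 0; the commit drops that check. If the same sanity check is still wanted, an equivalent against the token table might look like this (a sketch, not part of the commit; it assumes the blank symbol is spelled "<blk>" in tokens.txt):

    token_table = SymbolTable.from_file(args.tokens)
    # Mirror of the removed `assert blank_id == 0`: id 0 must be the blank.
    assert token_table[0] == "<blk>", token_table[0]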
@@ -301,7 +303,7 @@ def main():

     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
-        results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
+        results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
         end_time = time.time()
         elapsed_seconds = end_time - start_time
         rtf = elapsed_seconds / total_duration
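For context, the rtf computed here is the real-time factor: decoding wall-clock time divided by total audio duration. For example, spending 360 seconds to decode 3600 seconds of audio gives rtf = 360 / 3600 = 0.1, i.e. ten times faster than real time.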