import argparse
from pathlib import Path
from typing import Callable, List, Union

import sentencepiece as spm
from k2 import SymbolTable


class Tokenizer:
    # Maps a raw text string to a list of tokenization units ("words").
    # Concrete subclasses assign this in their __init__.
    text2word: Callable[[str], List[str]]

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        group = parser.add_argument_group(title="Lang related options")

        group.add_argument("--lang", type=Path, help="Path to lang directory.")

        group.add_argument(
            "--lang-type",
            type=str,
            default=None,
            help=(
                "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exist. "
                "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
            ),
        )
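
    # A minimal wiring sketch for the options above; the argument values and
    # the lang directory path are hypothetical:
    #
    #   parser = argparse.ArgumentParser()
    #   Tokenizer.add_arguments(parser)
    #   args = parser.parse_args(["--lang", "data/lang_char", "--lang-type", "char"])
    #   tokenizer = Tokenizer.load(args.lang, args.lang_type)
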
    @staticmethod
    def Load(lang_dir: Path, lang_type="", oov="<unk>"):
        # If lang_type is not given, read it from the lang_dir/lang_type file.
        if not lang_type:
            assert (lang_dir / "lang_type").exists(), "lang_type not specified."
            lang_type = (lang_dir / "lang_type").read_text().strip()

        tokenizer = None

        if lang_type == "bpe":
            assert (
                lang_dir / "bpe.model"
            ).exists(), f"No BPE .model could be found in {lang_dir}."
            tokenizer = spm.SentencePieceProcessor()
            tokenizer.Load(str(lang_dir / "bpe.model"))
        elif lang_type == "char":
            tokenizer = CharTokenizer(lang_dir, oov=oov)
        else:
            raise NotImplementedError(f"{lang_type} not supported at the moment.")

        return tokenizer

    load = Load
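
    # Expected layout of the lang directory consumed by Load; the directory
    # names below are hypothetical, the file names are the ones checked above:
    #
    #   data/lang_bpe_500/             data/lang_char/
    #       bpe.model                      tokens.txt
    #       lang_type (contains "bpe")     lang_type (contains "char")
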
    def PieceToId(self, piece: str) -> int:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    piece_to_id = PieceToId

    def IdToPiece(self, id: int) -> str:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    id_to_piece = IdToPiece

    def GetPieceSize(self) -> int:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    get_piece_size = GetPieceSize

    def __len__(self) -> int:
        return self.get_piece_size()
    def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def EncodeAsIds(self, input: str) -> List[int]:
        return self.EncodeAsIdsBatch([input])[0]

    def EncodeAsPieces(self, input: str) -> List[str]:
        return self.EncodeAsPiecesBatch([input])[0]

    def Encode(
        self, input: Union[str, List[str]], out_type=int
    ) -> Union[List, List[List]]:
        if not input:
            return []

        if isinstance(input, list):
            if out_type is int:
                return self.EncodeAsIdsBatch(input)
            if out_type is str:
                return self.EncodeAsPiecesBatch(input)

        if out_type is int:
            return self.EncodeAsIds(input)
        if out_type is str:
            return self.EncodeAsPieces(input)

    encode = Encode
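
    # Sketch of the Encode dispatch above (pieces and ids are illustrative,
    # not drawn from a real vocabulary):
    #
    #   tokenizer.encode("abc", out_type=str)        # -> ["a", "b", "c"]
    #   tokenizer.encode(["ab", "c"], out_type=int)  # -> [[11, 12], [13]]
    #   tokenizer.encode("", out_type=int)           # -> []
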
    def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def DecodeIds(self, input: List[int]) -> str:
        return self.DecodeIdsBatch([input])[0]

    def DecodePieces(self, input: List[str]) -> str:
        return self.DecodePiecesBatch([input])[0]

    def Decode(
        self,
        input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
    ) -> Union[List[str], str]:
        if not input:
            return ""

        if isinstance(input, int):
            return self.id_to_piece(input)
        elif isinstance(input, str):
            raise TypeError(
                "Unlike spm.SentencePieceProcessor, cannot decode from type str."
            )

        # Dispatch on the type of the first element: a batch of lists vs. a
        # single list of ids or pieces.
        if isinstance(input[0], list):
            if not input[0] or isinstance(input[0][0], int):
                return self.DecodeIdsBatch(input)

            if isinstance(input[0][0], str):
                return self.DecodePiecesBatch(input)

        if isinstance(input[0], int):
            return self.DecodeIds(input)
        if isinstance(input[0], str):
            return self.DecodePieces(input)

        raise RuntimeError("Unknown input type")

    decode = Decode
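
    # Sketch of the Decode dispatch above, assuming a CharTokenizer with
    # sep="" (ids and pieces are illustrative, not from a real tokens.txt):
    #
    #   tokenizer.decode(7)                     # single id       -> "あ"
    #   tokenizer.decode([11, 12, 13])          # list of ids     -> "こんに"
    #   tokenizer.decode([["a", "b"], ["c"]])   # batch of pieces -> ["ab", "c"]
    #   tokenizer.decode("abc")                 # raises TypeError
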
    def SplitBatch(self, input: List[str]) -> List[List[str]]:
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
        if isinstance(input, list):
            return self.SplitBatch(input)
        elif isinstance(input, str):
            return self.SplitBatch([input])[0]
        raise RuntimeError("Unknown input type")

    split = Split


class CharTokenizer(Tokenizer):
    def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
        assert (
            lang_dir / "tokens.txt"
        ).exists(), f"tokens.txt could not be found in {lang_dir}."
        token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
        assert (
            "#0" not in token_table
        ), "This tokenizer does not support disambig symbols."
        self._id2sym = token_table._id2sym
        self._sym2id = token_table._sym2id
        self.oov = oov
        self.oov_id = self._sym2id[oov]
        self.sep = sep
        # With a separator, split on it; otherwise drop spaces and treat each
        # remaining character as a token.
        if self.sep:
            self.text2word = lambda x: x.split(self.sep)
        else:
            self.text2word = lambda x: list(x.replace(" ", ""))
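
    # Illustrative text2word behavior (example strings only):
    #
    #   sep=""  : "今日は いい天気"  -> ["今", "日", "は", "い", "い", "天", "気"]
    #   sep=" " : "分かち 書き に"   -> ["分かち", "書き", "に"]
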
    def piece_to_id(self, piece: str) -> int:
        try:
            return self._sym2id[piece]
        except KeyError:
            return self.oov_id

    def id_to_piece(self, id: int) -> str:
        return self._id2sym[id]

    def get_piece_size(self) -> int:
        return len(self._sym2id)

    def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
        return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]

    def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
        return [
            [i if i in self._sym2id else self.oov for i in self.text2word(text)]
            for text in input
        ]

    def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
        return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]

    def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
        return [self.sep.join(text) for text in input]

    def SplitBatch(self, input: List[str]) -> List[List[str]]:
        return [self.text2word(text) for text in input]


def test_CharTokenizer():
    test_single_string = "こんにちは"  # "Hello"
    test_multiple_string = [
        "今日はいい天気ですよね",  # "Nice weather today, isn't it"
        "諏訪湖は綺麗でしょう",  # "Lake Suwa must be beautiful"
        "这在词表外",  # Chinese: "this is outside the vocabulary" (OOV test)
        "分かち 書き に し た 文章 です",  # pre-segmented (space-separated) text
        "",
    ]
    test_empty_string = ""
    sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
    splitter = sp.split
    print(sp.encode(test_single_string, out_type=str))
    print(sp.encode(test_single_string, out_type=int))
    print(sp.encode(test_multiple_string, out_type=str))
    print(sp.encode(test_multiple_string, out_type=int))
    print(sp.encode(test_empty_string, out_type=str))
    print(sp.encode(test_empty_string, out_type=int))
    print(sp.decode(sp.encode(test_single_string, out_type=str)))
    print(sp.decode(sp.encode(test_single_string, out_type=int)))
    print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
    print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
    print(sp.decode(sp.encode(test_empty_string, out_type=str)))
    print(sp.decode(sp.encode(test_empty_string, out_type=int)))
    print(splitter(test_single_string))
    print(splitter(test_multiple_string))
    print(splitter(test_empty_string))


if __name__ == "__main__":
    test_CharTokenizer()
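
# The self-test above assumes a "lang_char" directory containing tokens.txt in
# the current working directory (see CharTokenizer.__init__). Assuming this
# file is saved as tokenizer.py, run it directly:
#
#   python tokenizer.py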