import argparse
from pathlib import Path
from typing import Callable, List, Union

import sentencepiece as spm
from k2 import SymbolTable


class Tokenizer:
    """Base tokenizer that mirrors the ``spm.SentencePieceProcessor`` API.

    Subclasses implement the ``*Batch`` primitives; the single-input and
    dispatching methods (``Encode``/``Decode``/``Split``) are provided here.
    Both CamelCase and snake_case names are exposed, matching sentencepiece.
    """

    # Maps a raw text string to its list of tokenization units (set by subclasses).
    text2word: Callable[[str], List[str]]

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register ``--lang`` and ``--lang-type`` options on ``parser``."""
        group = parser.add_argument_group(title="Lang related options")

        group.add_argument("--lang", type=Path, help="Path to lang directory.")
        group.add_argument(
            "--lang-type",
            type=str,
            default=None,
            help=(
                "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exist. "
                "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
            ),
        )

    @staticmethod
    def Load(lang_dir: Path, lang_type="", oov=""):
        """Instantiate a tokenizer from ``lang_dir``.

        Args:
            lang_dir: Directory containing the language resources
                (``bpe.model`` for 'bpe', ``tokens.txt`` for 'char', and
                optionally a ``lang_type`` file).
            lang_type: 'bpe' or 'char'. If empty, it is read from
                ``lang_dir/lang_type``.
            oov: Out-of-vocabulary symbol, forwarded to :class:`CharTokenizer`.
                NOTE(review): the empty-string default looks suspicious —
                upstream variants use "<unk>"; confirm which is intended.

        Returns:
            A ``spm.SentencePieceProcessor`` for 'bpe', or a
            :class:`CharTokenizer` for 'char'.

        Raises:
            NotImplementedError: For any other ``lang_type``.
        """
        if not lang_type:
            # Fall back to the lang_type file written alongside the resources.
            assert (lang_dir / "lang_type").exists(), "lang_type not specified."
            lang_type = (lang_dir / "lang_type").read_text().strip()

        tokenizer = None
        if lang_type == "bpe":
            assert (
                lang_dir / "bpe.model"
            ).exists(), f"No BPE .model could be found in {lang_dir}."
            tokenizer = spm.SentencePieceProcessor()
            tokenizer.Load(str(lang_dir / "bpe.model"))
        elif lang_type == "char":
            tokenizer = CharTokenizer(lang_dir, oov=oov)
        else:
            raise NotImplementedError(f"{lang_type} not supported at the moment.")

        return tokenizer

    load = Load

    def PieceToId(self, piece: str) -> int:
        """Map a single token string to its integer id."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    piece_to_id = PieceToId

    def IdToPiece(self, id: int) -> str:
        """Map a single integer id back to its token string."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    id_to_piece = IdToPiece

    def GetPieceSize(self) -> int:
        """Return the vocabulary size."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    get_piece_size = GetPieceSize

    def __len__(self) -> int:
        return self.get_piece_size()

    def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
        """Encode a batch of texts into lists of token ids."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
        """Encode a batch of texts into lists of token strings."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def EncodeAsIds(self, input: str) -> List[int]:
        """Encode a single text into a list of token ids."""
        return self.EncodeAsIdsBatch([input])[0]

    def EncodeAsPieces(self, input: str) -> List[str]:
        """Encode a single text into a list of token strings."""
        return self.EncodeAsPiecesBatch([input])[0]

    def Encode(
        self, input: Union[str, List[str]], out_type=int
    ) -> Union[List, List[List]]:
        """Encode text(s) into ids (``out_type=int``) or pieces (``out_type=str``).

        Accepts a single string or a batch (list of strings); an empty/falsy
        input yields ``[]``.

        Raises:
            ValueError: If ``out_type`` is neither ``int`` nor ``str``.
                (The original silently returned ``None`` in that case.)
        """
        if not input:
            return []

        if isinstance(input, list):
            if out_type is int:
                return self.EncodeAsIdsBatch(input)
            if out_type is str:
                return self.EncodeAsPiecesBatch(input)
        else:
            if out_type is int:
                return self.EncodeAsIds(input)
            if out_type is str:
                return self.EncodeAsPieces(input)
        raise ValueError(f"Unsupported out_type: {out_type}")

    encode = Encode

    def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
        """Decode a batch of id lists back into texts."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
        """Decode a batch of piece lists back into texts."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def DecodeIds(self, input: List[int]) -> str:
        """Decode a single list of ids back into text."""
        return self.DecodeIdsBatch([input])[0]

    def DecodePieces(self, input: List[str]) -> str:
        """Decode a single list of pieces back into text."""
        return self.DecodePiecesBatch([input])[0]

    def Decode(
        self,
        input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
    ) -> Union[List[str], str]:
        """Decode ids/pieces (single or batched) back into text.

        Dispatches on the element type of ``input``. An empty/falsy input
        yields ``""`` (even for batch inputs — kept for compatibility).

        Raises:
            TypeError: If ``input`` is a plain string (unlike
                ``spm.SentencePieceProcessor``, strings cannot be decoded).
            RuntimeError: For any other unrecognized input type.
        """
        if not input:
            return ""
        if isinstance(input, int):
            # A bare id decodes to its piece.
            return self.id_to_piece(input)
        elif isinstance(input, str):
            raise TypeError(
                "Unlike spm.SentencePieceProcessor, cannot decode from type str."
            )

        if isinstance(input[0], list):
            # Batch input; an empty first sublist is treated as ids.
            if not input[0] or isinstance(input[0][0], int):
                return self.DecodeIdsBatch(input)

            if isinstance(input[0][0], str):
                return self.DecodePiecesBatch(input)

        if isinstance(input[0], int):
            return self.DecodeIds(input)
        if isinstance(input[0], str):
            return self.DecodePieces(input)

        raise RuntimeError("Unknown input type")

    decode = Decode

    def SplitBatch(self, input: List[str]) -> List[List[str]]:
        """Split a batch of texts into their tokenization units."""
        raise NotImplementedError(
            "You need to implement this function in the child class."
        )

    def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
        """Split text(s) into tokenization units without mapping to ids."""
        if isinstance(input, list):
            return self.SplitBatch(input)
        elif isinstance(input, str):
            return self.SplitBatch([input])[0]
        raise RuntimeError("Unknown input type")

    split = Split


class CharTokenizer(Tokenizer):
    """Character-level tokenizer backed by a ``k2.SymbolTable`` (tokens.txt)."""

    def __init__(self, lang_dir: Path, oov="", sep=""):
        """Build a character tokenizer from ``lang_dir/tokens.txt``.

        Args:
            lang_dir: Directory containing ``tokens.txt``.
            oov: Symbol substituted for out-of-vocabulary characters.
                NOTE(review): with the default ``oov=""`` this raises
                ``KeyError`` unless the token table contains the empty
                symbol; upstream variants default to "<unk>" — confirm.
            sep: Word separator. If non-empty, input text is assumed
                pre-segmented and split on ``sep``; otherwise spaces are
                dropped and the text is split into individual characters.
        """
        assert (
            lang_dir / "tokens.txt"
        ).exists(), f"tokens.txt could not be found in {lang_dir}."
        token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
        assert (
            "#0" not in token_table
        ), "This tokenizer does not support disambig symbols."
        self._id2sym = token_table._id2sym
        self._sym2id = token_table._sym2id
        self.oov = oov
        self.oov_id = self._sym2id[oov]
        self.sep = sep
        if self.sep:
            # Pre-segmented input: units are separated by `sep`.
            self.text2word = lambda x: x.split(self.sep)
        else:
            # Character-level input: drop spaces, one unit per character.
            self.text2word = lambda x: list(x.replace(" ", ""))

    def piece_to_id(self, piece: str) -> int:
        """Map a token to its id, falling back to the OOV id if unknown."""
        try:
            return self._sym2id[piece]
        except KeyError:
            return self.oov_id

    # Keep the CamelCase half of the base-class API working on this subclass
    # (previously these still raised NotImplementedError from the base).
    PieceToId = piece_to_id

    def id_to_piece(self, id: int) -> str:
        """Map an id back to its token string."""
        return self._id2sym[id]

    IdToPiece = id_to_piece

    def get_piece_size(self) -> int:
        """Return the vocabulary size."""
        return len(self._sym2id)

    GetPieceSize = get_piece_size

    def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
        """Encode each text into ids; unknown tokens map to the OOV id."""
        return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]

    def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
        """Encode each text into pieces; unknown tokens map to the OOV symbol."""
        return [
            [i if i in self._sym2id else self.oov for i in self.text2word(text)]
            for text in input
        ]

    def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
        """Decode each id list back into text, joined by ``sep``."""
        return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]

    def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
        """Decode each piece list back into text, joined by ``sep``."""
        return [self.sep.join(text) for text in input]

    def SplitBatch(self, input: List[str]) -> List[List[str]]:
        """Split each text into its tokenization units."""
        return [self.text2word(text) for text in input]


def test_CharTokenizer():
    """Smoke-test CharTokenizer round trips against a local ``lang_char`` dir."""
    test_single_string = "こんにちは"
    test_multiple_string = [
        "今日はいい天気ですよね",
        "諏訪湖は綺麗でしょう",
        "这在词表外",
        "分かち 書き に し た 文章 です",
        "",
    ]
    test_empty_string = ""
    sp = Tokenizer.load(Path("lang_char"), "char", oov="")
    splitter = sp.split
    print(sp.encode(test_single_string, out_type=str))
    print(sp.encode(test_single_string, out_type=int))
    print(sp.encode(test_multiple_string, out_type=str))
    print(sp.encode(test_multiple_string, out_type=int))
    print(sp.encode(test_empty_string, out_type=str))
    print(sp.encode(test_empty_string, out_type=int))
    print(sp.decode(sp.encode(test_single_string, out_type=str)))
    print(sp.decode(sp.encode(test_single_string, out_type=int)))
    print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
    print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
    print(sp.decode(sp.encode(test_empty_string, out_type=str)))
    print(sp.decode(sp.encode(test_empty_string, out_type=int)))
    print(splitter(test_single_string))
    print(splitter(test_multiple_string))
    print(splitter(test_empty_string))


if __name__ == "__main__":
    test_CharTokenizer()