#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

"""Text pre-processing: SentencePiece tokenization plus token-id lookup."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union
from pathlib import Path

import numpy as np

try:
    import sentencepiece as spm
except ImportError:
    # Only ``Numericalizer.tokenizer`` needs sentencepiece; the token <-> id
    # mapping below is pure Python and stays usable without it.
    spm = None


class PreProcessor(ABC):
    """Interface for callables that turn a text line into token ids."""

    @abstractmethod
    def __call__(self, text: str) -> List[int]:
        """Return the list of integer token ids for ``text``."""
        raise NotImplementedError


class Numericalizer(PreProcessor):
    """Convert text to token ids with a SentencePiece model and a token list.

    The position of a token in ``token_list`` is its id. Every encoded
    sentence is wrapped in the start-of-sentence and end-of-sentence ids.

    Args:
        tokenizer_file: Path of the SentencePiece model; loaded lazily on
            first access of :attr:`tokenizer`.
        token_list: Sequence of token strings; must not contain duplicates.
        unk_symbol: Token whose id replaces pieces missing from
            ``token_list``.

    Raises:
        ValueError: If ``token_list`` lacks the sos/eos symbols or
            ``unk_symbol``.
        RuntimeError: If ``token_list`` contains a duplicated token.
    """

    # NOTE(review): the angle-bracket symbols were empty strings in the
    # received file (they appear to have been stripped as markup). They are
    # restored here to the ESPnet token-list convention; confirm against the
    # upstream source.
    SOS_EOS_SYMBOL = "<sos/eos>"
    SOS_SYMBOL = "<sos>"
    EOS_SYMBOL = "<eos>"

    def __init__(self, tokenizer_file, token_list, unk_symbol="<unk>"):
        self.tokenizer_file = tokenizer_file
        self.token_list = token_list
        self.unk_symbol = unk_symbol
        self._token2idx = None
        self._tokenizer = None
        self._assign_special_symbols()

    def _assign_special_symbols(self):
        """Set ``sos_idx``, ``eos_idx`` and ``unk_idx`` from the token list."""
        token2idx = self.token2idx
        # <sos> and <eos> share same index for model download from espnet
        # model zoo, so the combined symbol takes precedence when present.
        if self.SOS_EOS_SYMBOL in token2idx:
            self.sos_idx = self.eos_idx = token2idx[self.SOS_EOS_SYMBOL]
        elif self.SOS_SYMBOL in token2idx and self.EOS_SYMBOL in token2idx:
            self.sos_idx = token2idx[self.SOS_SYMBOL]
            self.eos_idx = token2idx[self.EOS_SYMBOL]
        else:
            # Explicit raise instead of ``assert`` so the check survives -O.
            raise ValueError(
                f'token_list must contain "{self.SOS_EOS_SYMBOL}" or both '
                f'"{self.SOS_SYMBOL}" and "{self.EOS_SYMBOL}"'
            )
        if self.unk_symbol not in token2idx:
            raise ValueError(f'token_list must contain "{self.unk_symbol}"')
        self.unk_idx = token2idx[self.unk_symbol]

    @property
    def tokenizer(self):
        """SentencePiece processor, loaded from ``tokenizer_file`` on first use."""
        if self._tokenizer is None:
            if spm is None:
                raise ImportError(
                    "sentencepiece is required to load the tokenizer model"
                )
            sp = spm.SentencePieceProcessor()
            sp.Load(self.tokenizer_file)
            self._tokenizer = sp
        return self._tokenizer

    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into SentencePiece pieces."""
        return self.tokenizer.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join SentencePiece pieces back into a text string."""
        return self.tokenizer.DecodePieces(list(tokens))

    @property
    def token2idx(self):
        """Mapping from token string to its index in ``token_list``.

        Built on first access and cached.

        Raises:
            RuntimeError: If a token occurs more than once in ``token_list``.
        """
        if self._token2idx is None:
            # Build into a local dict so a duplicate does not leave a
            # half-filled mapping cached on the instance.
            mapping = {}
            for idx, token in enumerate(self.token_list):
                if token in mapping:
                    raise RuntimeError(f'Symbol "{token}" is duplicated')
                mapping[token] = idx
            self._token2idx = mapping
        return self._token2idx

    def ids2tokens(
        self, integers: Union[np.ndarray, Iterable[int]]
    ) -> List[str]:
        """Look up the token string for every id in ``integers``.

        Raises:
            ValueError: If ``integers`` is an ndarray that is not 1-dimensional.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def __call__(self, text: str) -> List[int]:
        """Encode ``text`` as ``[sos] + piece ids + [eos]``.

        Pieces that are not in ``token_list`` map to ``unk_idx``.
        """
        tokens = self.text2tokens(text)
        token_idxs = (
            [self.sos_idx]
            + [self.token2idx.get(token, self.unk_idx) for token in tokens]
            + [self.eos_idx]
        )
        return token_idxs