moved num_tokens to utils.py

moved `num_tokens` to `icefall/utils.py` to reduce code redundancy
This commit is contained in:
jinzr 2023-07-06 12:41:29 +08:00
parent e3ec8932e5
commit f5257b1528
5 changed files with 23 additions and 84 deletions

View File

@ -100,27 +100,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import str2bool, num_tokens
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
def get_parser(): def get_parser():

View File

@ -68,7 +68,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import setup_logger, str2bool, num_tokens
def get_parser(): def get_parser():
@ -160,26 +160,6 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
onnx.save(model, filename) onnx.save(model, filename)
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
class OnnxEncoder(nn.Module): class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner""" """A wrapper for Zipformer and the encoder_proj from the joiner"""

View File

@ -100,27 +100,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import str2bool, num_tokens
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
def get_parser(): def get_parser():

View File

@ -176,27 +176,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import make_pad_mask, str2bool from icefall.utils import make_pad_mask, str2bool, num_tokens
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
def get_parser(): def get_parser():

View File

@ -1899,3 +1899,22 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str):
except OSError: except OSError:
copyfile(src=exp_dir / src, dst=exp_dir / dst) copyfile(src=exp_dir / src, dst=exp_dir / dst)
os.close(dir_fd) os.close(dir_fd)
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens