mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
moved num_tokens
to utils.py
moved `num_tokens` to `icefall/utils.py` to reduce code redundancy
This commit is contained in:
parent
06cb1346ac
commit
ccb6031853
@ -100,27 +100,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def num_tokens(
    token_table: "k2.SymbolTable",
    disambig_pattern: "re.Pattern" = re.compile(r"^#\d+$"),
) -> int:
    """Return the number of tokens excluding those from
    disambiguation symbols.

    Args:
      token_table:
        Table mapping token symbols to integer IDs. Only its ``symbols``
        attribute and ``token_table[symbol]`` lookup are used.
      disambig_pattern:
        Regular expression that is matched (with ``match``, i.e. anchored
        at the start of the symbol) against every symbol; symbols that
        match are treated as disambiguation symbols and are not counted.
        A compiled pattern is expected; a plain pattern string is also
        accepted and compiled on the fly.

    Returns:
      The number of symbols that do not match ``disambig_pattern``,
      minus one if any of them has ID 0.

    Caution:
      0 is not a token ID so it is excluded from the return value.
    """
    # The parameter used to be annotated as ``str`` although ``.match`` is
    # called on it, so tolerate callers that pass a pattern string.
    if isinstance(disambig_pattern, str):
        disambig_pattern = re.compile(disambig_pattern)

    token_ids = [
        token_table[symbol]
        for symbol in token_table.symbols
        if not disambig_pattern.match(symbol)
    ]

    # ID 0 is not considered a token, so do not count it.
    return len(token_ids) - (1 if 0 in token_ids else 0)
|
||||
from icefall.utils import str2bool, num_tokens
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -68,7 +68,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
from icefall.utils import setup_logger, str2bool, num_tokens
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -160,26 +160,6 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def num_tokens(
    token_table: "k2.SymbolTable",
    disambig_pattern: "re.Pattern" = re.compile(r"^#\d+$"),
) -> int:
    """Return the number of tokens excluding those from
    disambiguation symbols.

    Args:
      token_table:
        Table mapping token symbols to integer IDs. Only its ``symbols``
        attribute and ``token_table[symbol]`` lookup are used.
      disambig_pattern:
        Regular expression that is matched (with ``match``, i.e. anchored
        at the start of the symbol) against every symbol; symbols that
        match are treated as disambiguation symbols and are not counted.
        A compiled pattern is expected; a plain pattern string is also
        accepted and compiled on the fly.

    Returns:
      The number of symbols that do not match ``disambig_pattern``,
      minus one if any of them has ID 0.

    Caution:
      0 is not a token ID so it is excluded from the return value.
    """
    # The parameter used to be annotated as ``str`` although ``.match`` is
    # called on it, so tolerate callers that pass a pattern string.
    if isinstance(disambig_pattern, str):
        disambig_pattern = re.compile(disambig_pattern)

    token_ids = [
        token_table[symbol]
        for symbol in token_table.symbols
        if not disambig_pattern.match(symbol)
    ]

    # ID 0 is not considered a token, so do not count it.
    return len(token_ids) - (1 if 0 in token_ids else 0)
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
|
@ -100,27 +100,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def num_tokens(
    token_table: "k2.SymbolTable",
    disambig_pattern: "re.Pattern" = re.compile(r"^#\d+$"),
) -> int:
    """Return the number of tokens excluding those from
    disambiguation symbols.

    Args:
      token_table:
        Table mapping token symbols to integer IDs. Only its ``symbols``
        attribute and ``token_table[symbol]`` lookup are used.
      disambig_pattern:
        Regular expression that is matched (with ``match``, i.e. anchored
        at the start of the symbol) against every symbol; symbols that
        match are treated as disambiguation symbols and are not counted.
        A compiled pattern is expected; a plain pattern string is also
        accepted and compiled on the fly.

    Returns:
      The number of symbols that do not match ``disambig_pattern``,
      minus one if any of them has ID 0.

    Caution:
      0 is not a token ID so it is excluded from the return value.
    """
    # The parameter used to be annotated as ``str`` although ``.match`` is
    # called on it, so tolerate callers that pass a pattern string.
    if isinstance(disambig_pattern, str):
        disambig_pattern = re.compile(disambig_pattern)

    token_ids = [
        token_table[symbol]
        for symbol in token_table.symbols
        if not disambig_pattern.match(symbol)
    ]

    # ID 0 is not considered a token, so do not count it.
    return len(token_ids) - (1 if 0 in token_ids else 0)
|
||||
from icefall.utils import str2bool, num_tokens
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -176,27 +176,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, str2bool
|
||||
|
||||
|
||||
def num_tokens(
    token_table: "k2.SymbolTable",
    disambig_pattern: "re.Pattern" = re.compile(r"^#\d+$"),
) -> int:
    """Return the number of tokens excluding those from
    disambiguation symbols.

    Args:
      token_table:
        Table mapping token symbols to integer IDs. Only its ``symbols``
        attribute and ``token_table[symbol]`` lookup are used.
      disambig_pattern:
        Regular expression that is matched (with ``match``, i.e. anchored
        at the start of the symbol) against every symbol; symbols that
        match are treated as disambiguation symbols and are not counted.
        A compiled pattern is expected; a plain pattern string is also
        accepted and compiled on the fly.

    Returns:
      The number of symbols that do not match ``disambig_pattern``,
      minus one if any of them has ID 0.

    Caution:
      0 is not a token ID so it is excluded from the return value.
    """
    # The parameter used to be annotated as ``str`` although ``.match`` is
    # called on it, so tolerate callers that pass a pattern string.
    if isinstance(disambig_pattern, str):
        disambig_pattern = re.compile(disambig_pattern)

    token_ids = [
        token_table[symbol]
        for symbol in token_table.symbols
        if not disambig_pattern.match(symbol)
    ]

    # ID 0 is not considered a token, so do not count it.
    return len(token_ids) - (1 if 0 in token_ids else 0)
|
||||
from icefall.utils import make_pad_mask, str2bool, num_tokens
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -2060,3 +2060,22 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str):
|
||||
except OSError:
|
||||
copyfile(src=exp_dir / src, dst=exp_dir / dst)
|
||||
os.close(dir_fd)
|
||||
|
||||
def num_tokens(
    token_table: "k2.SymbolTable",
    disambig_pattern: "re.Pattern" = re.compile(r"^#\d+$"),
) -> int:
    """Return the number of tokens excluding those from
    disambiguation symbols.

    Args:
      token_table:
        Table mapping token symbols to integer IDs. Only its ``symbols``
        attribute and ``token_table[symbol]`` lookup are used.
      disambig_pattern:
        Regular expression that is matched (with ``match``, i.e. anchored
        at the start of the symbol) against every symbol; symbols that
        match are treated as disambiguation symbols and are not counted.
        A compiled pattern is expected; a plain pattern string is also
        accepted and compiled on the fly.

    Returns:
      The number of symbols that do not match ``disambig_pattern``,
      minus one if any of them has ID 0.

    Caution:
      0 is not a token ID so it is excluded from the return value.
    """
    # The parameter used to be annotated as ``str`` although ``.match`` is
    # called on it, so tolerate callers that pass a pattern string.
    if isinstance(disambig_pattern, str):
        disambig_pattern = re.compile(disambig_pattern)

    token_ids = [
        token_table[symbol]
        for symbol in token_table.symbols
        if not disambig_pattern.match(symbol)
    ]

    # ID 0 is not considered a token, so do not count it.
    return len(token_ids) - (1 if 0 in token_ids else 0)
|
Loading…
x
Reference in New Issue
Block a user