mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
moved num_tokens
to utils.py
moved `num_tokens` to `icefall/utils.py` to reduce code redundancy
This commit is contained in:
parent
e3ec8932e5
commit
f5257b1528
@ -100,27 +100,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def num_tokens(
|
|
||||||
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
|
|
||||||
) -> int:
|
|
||||||
"""Return the number of tokens excluding those from
|
|
||||||
disambiguation symbols.
|
|
||||||
|
|
||||||
Caution:
|
|
||||||
0 is not a token ID so it is excluded from the return value.
|
|
||||||
"""
|
|
||||||
symbols = token_table.symbols
|
|
||||||
ans = []
|
|
||||||
for s in symbols:
|
|
||||||
if not disambig_pattern.match(s):
|
|
||||||
ans.append(token_table[s])
|
|
||||||
num_tokens = len(ans)
|
|
||||||
if 0 in ans:
|
|
||||||
num_tokens -= 1
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -68,7 +68,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import setup_logger, str2bool, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -160,26 +160,6 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
|||||||
onnx.save(model, filename)
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
def num_tokens(
|
|
||||||
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
|
|
||||||
) -> int:
|
|
||||||
"""Return the number of tokens excluding those from
|
|
||||||
disambiguation symbols.
|
|
||||||
|
|
||||||
Caution:
|
|
||||||
0 is not a token ID so it is excluded from the return value.
|
|
||||||
"""
|
|
||||||
symbols = token_table.symbols
|
|
||||||
ans = []
|
|
||||||
for s in symbols:
|
|
||||||
if not disambig_pattern.match(s):
|
|
||||||
ans.append(token_table[s])
|
|
||||||
num_tokens = len(ans)
|
|
||||||
if 0 in ans:
|
|
||||||
num_tokens -= 1
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxEncoder(nn.Module):
|
class OnnxEncoder(nn.Module):
|
||||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
@ -100,27 +100,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def num_tokens(
|
|
||||||
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
|
|
||||||
) -> int:
|
|
||||||
"""Return the number of tokens excluding those from
|
|
||||||
disambiguation symbols.
|
|
||||||
|
|
||||||
Caution:
|
|
||||||
0 is not a token ID so it is excluded from the return value.
|
|
||||||
"""
|
|
||||||
symbols = token_table.symbols
|
|
||||||
ans = []
|
|
||||||
for s in symbols:
|
|
||||||
if not disambig_pattern.match(s):
|
|
||||||
ans.append(token_table[s])
|
|
||||||
num_tokens = len(ans)
|
|
||||||
if 0 in ans:
|
|
||||||
num_tokens -= 1
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -176,27 +176,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, str2bool
|
from icefall.utils import make_pad_mask, str2bool, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def num_tokens(
|
|
||||||
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
|
|
||||||
) -> int:
|
|
||||||
"""Return the number of tokens excluding those from
|
|
||||||
disambiguation symbols.
|
|
||||||
|
|
||||||
Caution:
|
|
||||||
0 is not a token ID so it is excluded from the return value.
|
|
||||||
"""
|
|
||||||
symbols = token_table.symbols
|
|
||||||
ans = []
|
|
||||||
for s in symbols:
|
|
||||||
if not disambig_pattern.match(s):
|
|
||||||
ans.append(token_table[s])
|
|
||||||
num_tokens = len(ans)
|
|
||||||
if 0 in ans:
|
|
||||||
num_tokens -= 1
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -1899,3 +1899,22 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str):
|
|||||||
except OSError:
|
except OSError:
|
||||||
copyfile(src=exp_dir / src, dst=exp_dir / dst)
|
copyfile(src=exp_dir / src, dst=exp_dir / dst)
|
||||||
os.close(dir_fd)
|
os.close(dir_fd)
|
||||||
|
|
||||||
|
def num_tokens(
|
||||||
|
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
|
||||||
|
) -> int:
|
||||||
|
"""Return the number of tokens excluding those from
|
||||||
|
disambiguation symbols.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
0 is not a token ID so it is excluded from the return value.
|
||||||
|
"""
|
||||||
|
symbols = token_table.symbols
|
||||||
|
ans = []
|
||||||
|
for s in symbols:
|
||||||
|
if not disambig_pattern.match(s):
|
||||||
|
ans.append(token_table[s])
|
||||||
|
num_tokens = len(ans)
|
||||||
|
if 0 in ans:
|
||||||
|
num_tokens -= 1
|
||||||
|
return num_tokens
|
Loading…
x
Reference in New Issue
Block a user