Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 16:44:20 +00:00)

Commit 91f4e52fcf: Merge branch 'k2-fsa:master' into phone2
@@ -38,7 +38,7 @@ Please fix any issues reported by the check tools.
 .. HINT::

    Some of the check tools, i.e., ``black`` and ``isort`` will modify
-   the files to be commited **in-place**. So please run ``git status``
+   the files to be committed **in-place**. So please run ``git status``
    after failure to see which file has been modified by the tools
    before you make any further changes.
@@ -56,7 +56,7 @@ during decoding for transducer model:
     \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
     \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)

-In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
+In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
 the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
 LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
 As a bi-gram is much faster to evaluate, LODR is usually much faster.
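For intuition, the interpolation above is straightforward to express in code. A minimal sketch of the per-token LODR score (hypothetical function name and weights for illustration, not icefall's actual decoder API):

    def lodr_score(
        logp_rnnt: float,       # transducer log-prob for token y_u
        logp_target_lm: float,  # target-domain LM log-prob
        logp_bigram: float,     # source-domain bi-gram LM log-prob
        lam1: float = 0.3,      # example weights, tuned per dataset
        lam2: float = 0.1,
    ) -> float:
        # add the target-domain LM, subtract the source-domain bi-gram
        return logp_rnnt + lam1 * logp_target_lm - lam2 * logp_bigram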
@@ -125,7 +125,7 @@ Python code. We have also set up ``PATH`` so that you can use
 .. caution::

    Please don't use `<https://github.com/tencent/ncnn>`_.
-   We have made some modifications to the offical `ncnn`_.
+   We have made some modifications to the official `ncnn`_.

    We will synchronize `<https://github.com/csukuangfj/ncnn>`_ periodically
    with the official one.
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+#                                            Mingshuang Luo
+#                                            Zengwei Yao,
+#                                            Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,21 +22,35 @@
 Usage:


-# For mix precision training:
+# For mix precision training, using MCP style transcript:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./zipformer/train.py \
+./zipformer_prompt_asr/train_baseline.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
+  --exp-dir zipformer_prompt_asr/exp \
+  --transcript-style MCP \
+  --max-duration 1000
+
+# For mix precision training, using UC style transcript:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_prompt_asr/train_baseline.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_prompt_asr/exp \
+  --transcript-style UC \
   --max-duration 1000

 # To train a streaming model

-./zipformer/train.py \
+./zipformer_prompt_asr/train_baseline.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
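For readers unfamiliar with the two acronyms, a rough illustration of the transcript styles on the same sentence (illustrative only; the exact normalization rules, e.g. whether apostrophes survive, are defined in the recipe itself):

    raw = "Hello, world! It's John's book."
    # UC:  upper-cased text without punctuation
    uc  = "HELLO WORLD IT'S JOHN'S BOOK"
    # MCP: mixed-cased text with punctuation kept
    mcp = "Hello, world! It's John's book."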
@@ -100,7 +115,7 @@ from icefall.utils import (
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


-def get_first(
+def get_mixed_cased_with_punc(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
@@ -479,6 +494,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--transcript-style",
+        type=str,
+        default="UC",
+        choices=["UC", "MCP"],
+        help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations,
+        MCP stands for mix-cased text with punctuation.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
@@ -1223,7 +1248,11 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = get_upper_only_alpha
+    if params.transcript_style == "UC":
+        text_sampling_func = get_upper_only_alpha
+    else:
+        text_sampling_func = get_mixed_cased_with_punc
+    logging.info(f"Using {params.transcript_style} style for training.")
     logging.info(f"Text sampling func: {text_sampling_func}")
     train_dl = libriheavy.train_dataloaders(
         train_cuts,
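For reference, a minimal sketch of what a UC-style sampling function such as get_upper_only_alpha plausibly does (an assumption for illustration; see the recipe for the real implementation):

    import re

    def upper_only_alpha(text: str) -> str:
        # Hypothetical UC normalizer: drop non-letter characters, upper-case.
        text = re.sub(r"[^a-zA-Z' ]", "", text)
        return re.sub(r"\s+", " ", text).strip().upper()

    print(upper_only_alpha("Hello, world! It's a test."))
    # -> "HELLO WORLD IT'S A TEST"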
@@ -655,8 +655,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)
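The appended "." is the substance of this fix: a bare startswith(module) with module="encoder" also matches keys under sibling modules such as "encoder_embed", so src_keys and dst_keys could disagree and trip the assert. A quick demonstration:

    keys = ["encoder.layers.0.weight", "encoder_embed.conv.weight"]
    print([k for k in keys if k.startswith("encoder")])
    # ['encoder.layers.0.weight', 'encoder_embed.conv.weight']  (unintended match)
    print([k for k in keys if k.startswith("encoder" + ".")])
    # ['encoder.layers.0.weight']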
@@ -1089,6 +1093,9 @@ def run(rank, world_size, args):
         checkpoints = load_model_params(
             ckpt=params.finetune_ckpt, model=model, init_modules=modules
         )
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
     else:
         assert params.start_epoch > 0, params.start_epoch
         checkpoints = load_checkpoint_if_available(
@@ -114,6 +114,9 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

+        # x.T_dim == max(x_len)
+        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())
+
         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

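The added assert pins down an implicit contract: the feature tensor must be padded exactly to the longest utterance in the batch. In other words:

    import torch

    x_lens = torch.tensor([95, 100, 80])  # valid frames per utterance
    x = torch.zeros(3, 100, 80)           # (batch, T, feat_dim); T == max(x_lens)
    assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())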
@@ -17,28 +17,33 @@
 # limitations under the License.

 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
+
 import torch
-import random
+
 from encoder_interface import EncoderInterface
 from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
+    Dropout2,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn

@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.

@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):

         x = self.in_proj(x)  # (time, batch, 2*channels)

-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
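In all three call sites the input is a 3-D (seq_len, batch, channels) tensor, so dim=2 and dim=-1 name the same axis and the change is behavior-preserving; a positive axis index is presumably friendlier to export tooling (my reading; the motivation is not stated in the commit). Sanity check:

    import torch

    x = torch.randn(10, 4, 9)  # (seq_len, batch_size, 3 * hidden_channels)
    assert all(
        torch.equal(p, q) for p, q in zip(x.chunk(3, dim=-1), x.chunk(3, dim=2))
    )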
@@ -203,7 +203,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
@@ -78,7 +78,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
         default=None,
         help="""
         Modules to be initialized. It matches all parameters starting with
-        a specific key. The keys are given with Comma seperated. If None,
+        a specific key. The keys are given with Comma separated. If None,
         all modules will be initialised. For example, if you only want to
         initialise all parameters staring with "encoder", use "encoder";
         if you want to initialise parameters starting with encoder or decoder,
@@ -498,8 +498,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)
@@ -244,12 +244,22 @@ class TensorDiagnostic(object):

                 if stats_type == "eigs":
                     try:
-                        eigs, _ = torch.symeig(stats)
+                        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                            eigs, _ = torch.linalg.eigh(stats)
+                        else:
+                            eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print("Error getting eigenvalues, trying another method.")
-                        eigs, _ = torch.eig(stats)
-                        stats = eigs.norm(dim=1).sqrt()
+                        print(
+                            "Error getting eigenvalues, trying another method."
+                        )
+                        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                            eigs, _ = torch.linalg.eig(stats)
+                            eigs = eigs.abs()
+                        else:
+                            eigs, _ = torch.eig(stats)
+                            eigs = eigs.norm(dim=1)
+                        stats = eigs.sqrt()
                     # sqrt so it reflects data magnitude, like stddev- not variance

                 if stats_type in ["rms", "stddev"]:
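Background for the hasattr guards: torch.symeig and torch.eig are deprecated and removed in recent PyTorch releases in favor of torch.linalg.eigh and torch.linalg.eig. The replacements also change the return format: torch.linalg.eig yields complex eigenvalues, so eigs.abs() replaces the old eigs.norm(dim=1) over real/imaginary columns; both compute eigenvalue magnitudes. A minimal version-portable sketch of the symmetric case:

    import torch

    stats = torch.randn(4, 4)
    stats = stats @ stats.t()  # symmetric, covariance-like matrix
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        eigs, _ = torch.linalg.eigh(stats)  # real eigenvalues for symmetric input
    else:
        eigs, _ = torch.symeig(stats)  # legacy API on old torch
    print(eigs.abs().sqrt())  # stddev-like scale, as in the diff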
@@ -569,11 +579,10 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
+                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                            class_name=get_class_name(_module))

         def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
@@ -587,11 +596,9 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
+                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                            class_name=get_class_name(_module))

         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)

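The isinstance(o, Tensor) guard matters because hooks receive whatever a module returns; output tuples may mix tensors with lengths, None, or plain Python values, and touching .dtype on those raised AttributeError. For example:

    import torch

    _output = (torch.randn(2, 3), None, 5)  # e.g. (features, cache, length)
    for i, o in enumerate(_output):
        if isinstance(o, torch.Tensor) and o.dtype in (
            torch.float32, torch.float16, torch.float64
        ):
            print(i, tuple(o.shape))  # only the float tensor is accumulated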
@@ -498,7 +498,7 @@ def store_transcripts(
     Returns:
       Return None.
     """
-    with open(filename, "w") as f:
+    with open(filename, "w", encoding="utf8") as f:
         for cut_id, ref, hyp in texts:
             if char_level:
                 ref = list("".join(ref))
@@ -523,7 +523,7 @@ def store_transcripts_and_timestamps(
     Returns:
       Return None.
     """
-    with open(filename, "w") as f:
+    with open(filename, "w", encoding="utf8") as f:
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)
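Pinning encoding="utf8" in both writers matters because open() otherwise uses the locale's default encoding (e.g. cp1252 on Windows), which cannot represent CJK transcripts:

    # would raise UnicodeEncodeError under a cp1252 default, fine with utf8
    with open("recogs.txt", "w", encoding="utf8") as f:
        print("cut-001:\tref=你好世界", file=f)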
@@ -1447,7 +1447,7 @@ def get_parameter_groups_with_lrs(
     This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
     named-parameters; we can, if needed, create a version without the names).

-    It provides a way to specifiy learning-rate scales inside the module, so that if
+    It provides a way to specify learning-rate scales inside the module, so that if
     any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
     scale the LR of any parameters inside that module or its submodules. Note: you
     can set module parameters outside the __init__ function, e.g.:
@@ -1607,10 +1607,10 @@ def tokenize_by_bpe_model(
     chars = pattern.split(txt.upper())
     mix_chars = [w for w in chars if len(w.strip()) > 0]
     for ch_or_w in mix_chars:
-        # ch_or_w is a single CJK charater(i.e., "你"), do nothing.
+        # ch_or_w is a single CJK character(i.e., "你"), do nothing.
         if pattern.fullmatch(ch_or_w) is not None:
             tokens.append(ch_or_w)
-        # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
+        # ch_or_w contains non-CJK characters(i.e., " IT'S OKAY "),
         # encode ch_or_w using bpe_model.
         else:
             for p in sp.encode_as_pieces(ch_or_w):
@@ -1624,7 +1624,7 @@ def tokenize_by_CJK_char(line: str) -> str:
     """
     Tokenize a line of text with CJK char.

-    Note: All return charaters will be upper case.
+    Note: All return characters will be upper case.

     Example:
       input = "你好世界是 hello world 的中文"
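Completing the docstring's example for context (expected output reconstructed from the function's description: CJK characters are space-separated one by one and everything is upper-cased):

    s = "你好世界是 hello world 的中文"
    # expected: "你 好 世 界 是 HELLO WORLD 的 中 文"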
@@ -1917,7 +1917,7 @@ def parse_bpe_timestamps_and_texts(
       A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
       containing multiple FSAs, which is expected to be the result
       of k2.shortest_path (otherwise the returned values won't
-      be meaningful). Its attribtutes `labels` and `aux_labels`
+      be meaningful). Its attributes `labels` and `aux_labels`
       are both BPE tokens.
     sp:
       The BPE model.
@@ -1977,7 +1977,7 @@ def parse_timestamps_and_texts(
       A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
       containing multiple FSAs, which is expected to be the result
       of k2.shortest_path (otherwise the returned values won't
-      be meaningful). Attribtute `labels` is the prediction unit,
+      be meaningful). Attribute `labels` is the prediction unit,
       e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
     word_table:
       The word symbol table.
@@ -2045,7 +2045,7 @@ def parse_fsa_timestamps_and_texts(
 ) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
     """Parse timestamps (in seconds) and texts for given decoded fsa paths.
     Currently it supports two cases:
-    (1) ctc-decoding, the attribtutes `labels` and `aux_labels`
+    (1) ctc-decoding, the attributes `labels` and `aux_labels`
     are both BPE tokens. In this case, sp should be provided.
     (2) HLG-based 1best, the attribtute `labels` is the prediction unit,
     e.g., phone or BPE tokens; attribute `aux_labels` is the word index.