mirror of https://github.com/k2-fsa/icefall.git

commit 91f4e52fcf: Merge branch 'k2-fsa:master' into phone2

@@ -38,7 +38,7 @@ Please fix any issues reported by the check tools.

 .. HINT::

    Some of the check tools, i.e., ``black`` and ``isort`` will modify
-   the files to be commited **in-place**. So please run ``git status``
+   the files to be committed **in-place**. So please run ``git status``
    after failure to see which file has been modified by the tools
    before you make any further changes.

@@ -56,7 +56,7 @@ during decoding for transducer model:

     \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
     \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)

-In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
+In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
 the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
 LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
 As a bi-gram is much faster to evaluate, LODR is usually much faster.

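To make the interpolation concrete, here is a minimal sketch of the per-token score combination described above; the function name and the default weights are illustrative, not the actual icefall API:

    def lodr_token_score(
        asr_logprob: float,        # log p(y_u | x, y_1:u-1) from the transducer
        target_lm_logprob: float,  # log p from the target-domain LM
        bigram_logprob: float,     # log p from the source-domain bi-gram LM
        lambda1: float = 0.4,
        lambda2: float = 0.1,
    ) -> float:
        # LODR: reward tokens the target-domain LM likes, and subtract the
        # source-domain bi-gram score, a low-order estimate of the internal LM.
        return asr_logprob + lambda1 * target_lm_logprob - lambda2 * bigram_logprob
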
@@ -125,7 +125,7 @@ Python code. We have also set up ``PATH`` so that you can use

 .. caution::

    Please don't use `<https://github.com/tencent/ncnn>`_.
-   We have made some modifications to the offical `ncnn`_.
+   We have made some modifications to the official `ncnn`_.

    We will synchronize `<https://github.com/csukuangfj/ncnn>`_ periodically
    with the official one.

@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+#                                            Mingshuang Luo
+#                                            Zengwei Yao,
+#                                            Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -21,21 +22,35 @@
 Usage:


-# For mix precision training:
+# For mix precision training, using MCP style transcript:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./zipformer/train.py \
+./zipformer_prompt_asr/train_baseline.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
+  --exp-dir zipformer_prompt_asr/exp \
+  --transcript-style MCP \
   --max-duration 1000

+# For mix precision training, using UC style transcript:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_prompt_asr/train_baseline.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_prompt_asr/exp \
+  --transcript-style UC \
+  --max-duration 1000
+
 # To train a streaming model

-./zipformer/train.py \
+./zipformer_prompt_asr/train_baseline.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \

@@ -100,7 +115,7 @@ from icefall.utils import (
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


-def get_first(
+def get_mixed_cased_with_punc(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,

@@ -479,6 +494,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--transcript-style",
+        type=str,
+        default="UC",
+        choices=["UC", "MCP"],
+        help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations,
+        MCP stands for mix-cased text with punctuation.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser

@@ -1223,7 +1248,11 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = get_upper_only_alpha
+    if params.transcript_style == "UC":
+        text_sampling_func = get_upper_only_alpha
+    else:
+        text_sampling_func = get_mixed_cased_with_punc
+    logging.info(f"Using {params.transcript_style} style for training.")
     logging.info(f"Text sampling func: {text_sampling_func}")
     train_dl = libriheavy.train_dataloaders(
         train_cuts,

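For intuition about the two styles: UC is upper-cased text without punctuation, MCP is mixed-cased text with punctuation. A self-contained sketch of a UC-style normalizer follows; it is an illustration, not icefall's actual get_upper_only_alpha:

    import re

    def upper_only_alpha(text: str) -> str:
        # "Hello, world! It's okay." -> "HELLO WORLD IT'S OKAY"
        # Keep letters, apostrophes and spaces; drop the rest; upper-case.
        text = re.sub(r"[^A-Za-z' ]+", " ", text)
        return " ".join(text.upper().split())
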
@@ -655,8 +655,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)

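The trailing dot is the point of this change: startswith("encoder") would also match sibling modules such as encoder_embed, silently copying parameters that were never requested. A quick illustration (the key names are hypothetical):

    keys = ["encoder.layers.0.weight", "encoder_embed.conv.weight"]

    # Old behaviour: the bare prefix "encoder" matches both keys.
    print([k for k in keys if k.startswith("encoder")])
    # -> ['encoder.layers.0.weight', 'encoder_embed.conv.weight']

    # New behaviour: "encoder." matches only true submodules of encoder.
    print([k for k in keys if k.startswith("encoder" + ".")])
    # -> ['encoder.layers.0.weight']
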
@@ -1089,6 +1093,9 @@ def run(rank, world_size, args):
         checkpoints = load_model_params(
             ckpt=params.finetune_ckpt, model=model, init_modules=modules
         )
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
     else:
         assert params.start_epoch > 0, params.start_epoch
         checkpoints = load_checkpoint_if_available(

@@ -114,6 +114,9 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

+        # x.T_dim == max(x_len)
+        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())
+
         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

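The new assertion pins down the padding invariant: the time axis of the padded feature batch must equal the longest utterance recorded in x_lens. A standalone illustration:

    import torch

    x = torch.randn(3, 10, 80)         # (batch, T, feature), padded to T=10
    x_lens = torch.tensor([10, 7, 4])  # valid frames per utterance
    # T must equal the longest utterance, otherwise x was mis-padded:
    assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())
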
@@ -17,28 +17,33 @@
 # limitations under the License.

 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging

 import torch
-import random
 from encoder_interface import EncoderInterface
 from scaling import (
-    Balancer,
-    BiasNorm,
-    Dropout2,
-    ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
-    Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
+    Balancer,
+    BiasNorm,
+    ChunkCausalDepthwiseConv1d,
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
+    Whiten,
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn

@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.

@@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.
         s = self.tanh(s)

@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):

         x = self.in_proj(x)  # (time, batch, 2*channels)

-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.

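These three changes are behaviour-preserving: for the (seq_len, batch, channels) tensors used here, dim=2 and dim=-1 name the same axis, and an explicit positive dim can be friendlier to export and tracing tooling. A quick check:

    import torch

    x = torch.randn(5, 2, 6)  # (seq_len, batch, channels)
    a, b = x.chunk(2, dim=-1)
    c, d = x.chunk(2, dim=2)
    assert torch.equal(a, c) and torch.equal(b, d)  # same split, same axis
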
@@ -203,7 +203,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )

@@ -78,7 +78,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
         default=None,
         help="""
         Modules to be initialized. It matches all parameters starting with
-        a specific key. The keys are given with Comma seperated. If None,
+        a specific key. The keys are given with Comma separated. If None,
         all modules will be initialised. For example, if you only want to
         initialise all parameters staring with "encoder", use "encoder";
         if you want to initialise parameters starting with encoder or decoder,

@@ -498,8 +498,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)

@@ -244,12 +244,22 @@ class TensorDiagnostic(object):

             if stats_type == "eigs":
                 try:
-                    eigs, _ = torch.symeig(stats)
+                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                        eigs, _ = torch.linalg.eigh(stats)
+                    else:
+                        eigs, _ = torch.symeig(stats)
                     stats = eigs.abs().sqrt()
                 except:  # noqa
-                    print("Error getting eigenvalues, trying another method.")
-                    eigs, _ = torch.eig(stats)
-                    stats = eigs.norm(dim=1).sqrt()
+                    print(
+                        "Error getting eigenvalues, trying another method."
+                    )
+                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                        eigs, _ = torch.linalg.eig(stats)
+                        eigs = eigs.abs()
+                    else:
+                        eigs, _ = torch.eig(stats)
+                        eigs = eigs.norm(dim=1)
+                    stats = eigs.sqrt()
                 # sqrt so it reflects data magnitude, like stddev- not variance

             if stats_type in ["rms", "stddev"]:

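Background for this change: torch.symeig and torch.eig are deprecated in recent PyTorch releases and removed in newer ones, with torch.linalg.eigh and torch.linalg.eig as the replacements; the hasattr guards keep older PyTorch versions working. The return conventions differ slightly, which explains the extra .abs() in the fallback branch:

    import torch

    stats = torch.randn(4, 4)
    stats = stats @ stats.t()  # symmetric, as expected for "eigs" stats

    # torch.linalg.eigh: for symmetric input, real eigenvalues (ascending).
    eigs, _ = torch.linalg.eigh(stats)
    print(eigs.abs().sqrt())

    # torch.linalg.eig: general matrices, complex eigenvalues, hence the
    # .abs() before .sqrt() in the code above.
    eigs_c, _ = torch.linalg.eig(stats)
    print(eigs_c.abs().sqrt())
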
@@ -569,11 +579,10 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
-
+                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                            class_name=get_class_name(_module))

         def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]

@@ -587,11 +596,9 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
-
+                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                            class_name=get_class_name(_module))

         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)

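The added isinstance check matters because a module's output tuple may contain non-Tensor entries (e.g. None or Python ints), which have no .dtype attribute. A minimal illustration:

    import torch
    from torch import Tensor

    outputs = (torch.randn(2, 3), None, 42)
    for i, o in enumerate(outputs):
        # Without the isinstance guard, `o.dtype` raises AttributeError
        # as soon as it hits the None or the int.
        if isinstance(o, Tensor) and o.dtype in (torch.float32, torch.float16, torch.float64):
            print(f"output[{i}]: float tensor with shape {tuple(o.shape)}")
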
@@ -498,7 +498,7 @@ def store_transcripts(
     Returns:
       Return None.
     """
-    with open(filename, "w") as f:
+    with open(filename, "w", encoding="utf8") as f:
         for cut_id, ref, hyp in texts:
             if char_level:
                 ref = list("".join(ref))

@@ -523,7 +523,7 @@ def store_transcripts_and_timestamps(
     Returns:
       Return None.
     """
-    with open(filename, "w") as f:
+    with open(filename, "w", encoding="utf8") as f:
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)

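Making the encoding explicit matters because open() otherwise uses the platform-dependent locale encoding (e.g. cp1252 on Windows), which can raise UnicodeEncodeError on CJK transcripts:

    # Safe regardless of the platform's default locale:
    with open("recogs.txt", "w", encoding="utf8") as f:
        print("你好世界", file=f)
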
@@ -1447,7 +1447,7 @@ def get_parameter_groups_with_lrs(
     This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
     named-parameters; we can, if needed, create a version without the names).

-    It provides a way to specifiy learning-rate scales inside the module, so that if
+    It provides a way to specify learning-rate scales inside the module, so that if
     any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
     scale the LR of any parameters inside that module or its submodules. Note: you
     can set module parameters outside the __init__ function, e.g.:

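As a hedged illustration of the 'lr_scale' mechanism the docstring describes (the module below is made up; only the attribute name comes from the docstring):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(256, 512), nn.Linear(512, 512))
    # Setting a float `lr_scale` on a submodule scales the learning rate of
    # every parameter inside it (and its children) when the parameter groups
    # are built; it can be set outside __init__, as noted above.
    model[0].lr_scale = 0.5
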
@@ -1607,10 +1607,10 @@ def tokenize_by_bpe_model(
     chars = pattern.split(txt.upper())
     mix_chars = [w for w in chars if len(w.strip()) > 0]
     for ch_or_w in mix_chars:
-        # ch_or_w is a single CJK charater(i.e., "你"), do nothing.
+        # ch_or_w is a single CJK character(i.e., "你"), do nothing.
         if pattern.fullmatch(ch_or_w) is not None:
             tokens.append(ch_or_w)
-        # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
+        # ch_or_w contains non-CJK characters(i.e., " IT'S OKAY "),
         # encode ch_or_w using bpe_model.
         else:
             for p in sp.encode_as_pieces(ch_or_w):

@@ -1624,7 +1624,7 @@ def tokenize_by_CJK_char(line: str) -> str:
     """
     Tokenize a line of text with CJK char.

-    Note: All return charaters will be upper case.
+    Note: All return characters will be upper case.

     Example:
       input = "你好世界是 hello world 的中文"

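A self-contained sketch in the same spirit as tokenize_by_CJK_char (the CJK regex range here is illustrative; icefall's actual pattern covers more blocks):

    import re

    def tokenize_by_cjk_char(line: str) -> str:
        # "你好世界是 hello world 的中文"
        #   -> "你 好 世 界 是 HELLO WORLD 的 中 文"
        pattern = re.compile(r"([\u4e00-\u9fff])")  # CJK Unified Ideographs
        chars = pattern.split(line.strip())
        return " ".join(w.strip() for w in chars if w.strip()).upper()
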
@@ -1917,7 +1917,7 @@ def parse_bpe_timestamps_and_texts(
       A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
       containing multiple FSAs, which is expected to be the result
       of k2.shortest_path (otherwise the returned values won't
-      be meaningful). Its attribtutes `labels` and `aux_labels`
+      be meaningful). Its attributes `labels` and `aux_labels`
       are both BPE tokens.
     sp:
       The BPE model.

@@ -1977,7 +1977,7 @@ def parse_timestamps_and_texts(
       A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
       containing multiple FSAs, which is expected to be the result
       of k2.shortest_path (otherwise the returned values won't
-      be meaningful). Attribtute `labels` is the prediction unit,
+      be meaningful). Attribute `labels` is the prediction unit,
       e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
     word_table:
       The word symbol table.

@@ -2045,7 +2045,7 @@ def parse_fsa_timestamps_and_texts(
 ) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
     """Parse timestamps (in seconds) and texts for given decoded fsa paths.
     Currently it supports two cases:
-    (1) ctc-decoding, the attribtutes `labels` and `aux_labels`
+    (1) ctc-decoding, the attributes `labels` and `aux_labels`
     are both BPE tokens. In this case, sp should be provided.
     (2) HLG-based 1best, the attribtute `labels` is the prediction unit,
     e.g., phone or BPE tokens; attribute `aux_labels` is the word index.