Merge branch 'k2-fsa:master' into phone2

Yifan Yang 2023-10-20 09:09:31 -05:00 committed by GitHub
commit 91f4e52fcf
11 changed files with 108 additions and 53 deletions


@@ -38,7 +38,7 @@ Please fix any issues reported by the check tools.
 .. HINT::
   Some of the check tools, i.e., ``black`` and ``isort`` will modify
-  the files to be commited **in-place**. So please run ``git status``
+  the files to be committed **in-place**. So please run ``git status``
   after failure to see which file has been modified by the tools
   before you make any further changes.


@@ -56,7 +56,7 @@ during decoding for transducer model:
     \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
     \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)

-In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
+In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
 the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
 LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
 As a bi-gram is much faster to evaluate, LODR is usually much faster.
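To make the formula above concrete, here is a minimal sketch of how the three log-probabilities are combined for one candidate token during beam search. The function name and lambda values are illustrative, not the repo's API:

    def lodr_score(rnnt_logp: float,
                   target_lm_logp: float,
                   bigram_logp: float,
                   lam1: float = 0.4,
                   lam2: float = 0.1) -> float:
        # score = log p_rnnt + lam1 * log p_TargetLM - lam2 * log p_bigram;
        # subtracting the source-domain bi-gram plays the role of the
        # density-ratio denominator in DR, but is much cheaper to evaluate.
        return rnnt_logp + lam1 * target_lm_logp - lam2 * bigram_logp

    # e.g., RNNT log-prob -1.2, target-LM log-prob -0.8, bi-gram log-prob -0.5:
    print(lodr_score(-1.2, -0.8, -0.5))  # -1.47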


@@ -125,7 +125,7 @@ Python code. We have also set up ``PATH`` so that you can use
 .. caution::
   Please don't use `<https://github.com/tencent/ncnn>`_.
-  We have made some modifications to the offical `ncnn`_.
+  We have made some modifications to the official `ncnn`_.
   We will synchronize `<https://github.com/csukuangfj/ncnn>`_ periodically
   with the official one.


@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+#                                            Mingshuang Luo,
+#                                            Zengwei Yao,
+#                                            Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,21 +22,35 @@
 Usage:

-# For mix precision training:
+# For mix precision training, using MCP style transcript:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./zipformer/train.py \
+./zipformer_prompt_asr/train_baseline.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir zipformer/exp \
+  --exp-dir zipformer_prompt_asr/exp \
+  --transcript-style MCP \
   --max-duration 1000

+# For mix precision training, using UC style transcript:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_prompt_asr/train_baseline.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_prompt_asr/exp \
+  --transcript-style UC \
+  --max-duration 1000

 # To train a streaming model

-./zipformer/train.py \
+./zipformer_prompt_asr/train_baseline.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -100,7 +115,7 @@ from icefall.utils import (
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

-def get_first(
+def get_mixed_cased_with_punc(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
@@ -479,6 +494,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--transcript-style",
+        type=str,
+        default="UC",
+        choices=["UC", "MCP"],
+        help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations,
+        MCP stands for mix-cased text with punctuation.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
@@ -1223,7 +1248,11 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    text_sampling_func = get_upper_only_alpha
+    if params.transcript_style == "UC":
+        text_sampling_func = get_upper_only_alpha
+    else:
+        text_sampling_func = get_mixed_cased_with_punc
+    logging.info(f"Using {params.transcript_style} style for training.")
     logging.info(f"Text sampling func: {text_sampling_func}")
     train_dl = libriheavy.train_dataloaders(
         train_cuts,
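For context on the two --transcript-style values selected above, here is a toy illustration (not the repo's get_upper_only_alpha) of what the UC normalization roughly does, based on the --help text in the earlier hunk:

    import re

    def toy_upper_only_alpha(text: str) -> str:
        # UC: upper-cased text without punctuation (per the --help string);
        # keep letters and apostrophes, collapse everything else to spaces.
        stripped = re.sub(r"[^A-Z' ]+", " ", text.upper())
        return " ".join(stripped.split())

    text = "Hello, world! It's fine."
    print(toy_upper_only_alpha(text))  # HELLO WORLD IT'S FINE
    # MCP keeps the mixed-cased, punctuated transcript unchanged.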


@@ -655,8 +655,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)
@@ -1089,6 +1093,9 @@ def run(rank, world_size, args):
         checkpoints = load_model_params(
             ckpt=params.finetune_ckpt, model=model, init_modules=modules
         )
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
     else:
         assert params.start_epoch > 0, params.start_epoch
         checkpoints = load_checkpoint_if_available(
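The `.strip() + "."` change in the key matching above guards against prefix collisions between module names. A standalone illustration of the failure mode it fixes (the key names here are illustrative):

    # With plain startswith("encoder"), keys of the unrelated module
    # "encoder_embed" would be loaded too; appending "." restricts the
    # match to the exact submodule.
    keys = ["encoder.layers.0.weight", "encoder_embed.conv.weight"]
    print([k for k in keys if k.startswith("encoder")])
    # -> both keys match
    print([k for k in keys if k.startswith("encoder".strip() + ".")])
    # -> ['encoder.layers.0.weight']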


@@ -114,6 +114,9 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0

+        # x.T_dim == max(x_len)
+        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())
+
         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
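The new assertion pins down the padding contract between the feature batch and its lengths. A self-contained illustration of what it checks:

    import torch

    x = torch.zeros(3, 10, 80)          # (batch, T, feature_dim), padded batch
    x_lens = torch.tensor([10, 7, 4])   # valid frames per utterance
    # The time axis must equal the longest utterance, i.e. no stray padding.
    assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max())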


@@ -17,28 +17,33 @@
 # limitations under the License.

 import copy
+import logging
 import math
+import random
 import warnings
 from typing import List, Optional, Tuple, Union
-import logging
 import torch
-import random
 from encoder_interface import EncoderInterface
+from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
 from scaling import (
+    ActivationDropoutAndLinear,
     Balancer,
     BiasNorm,
-    Dropout2,
     ChunkCausalDepthwiseConv1d,
-    ActivationDropoutAndLinear,
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
     Whiten,
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+    convert_num_channels,
+    limit_param_value,
     penalize_abs_values_gt,
     softmax,
-    ScheduledFloat,
-    FloatLike,
-    limit_param_value,
-    convert_num_channels,
 )
 from torch import Tensor, nn
@@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module):
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.
@@ -2151,7 +2156,7 @@
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels

-        s, x, y = x.chunk(3, dim=-1)
+        s, x, y = x.chunk(3, dim=2)

         # s will go through tanh.
         s = self.tanh(s)
@@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module):
         x = self.in_proj(x)  # (time, batch, 2*channels)

-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.balancer1(s)
         s = self.sigmoid(s)
         x = self.activation1(x)  # identity.
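The three dim changes above are behavior-preserving: for the 3-D (seq_len, batch_size, channels) tensors used here, dim=-1 and dim=2 address the same axis, and the explicit index just makes the split dimension unambiguous. A quick check:

    import torch

    x = torch.randn(5, 2, 9)   # (seq_len, batch_size, 3 * hidden_channels)
    a = x.chunk(3, dim=-1)
    b = x.chunk(3, dim=2)
    print(all(torch.equal(p, q) for p, q in zip(a, b)))  # True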


@@ -203,7 +203,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An interger indicating how many candidates we will keep for each
+        help="""An integer indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )


@@ -78,7 +78,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
         default=None,
         help="""
         Modules to be initialized. It matches all parameters starting with
-        a specific key. The keys are given with Comma seperated. If None,
+        a specific key. The keys are given with Comma separated. If None,
         all modules will be initialised. For example, if you only want to
         initialise all parameters staring with "encoder", use "encoder";
         if you want to initialise parameters starting with encoder or decoder,
@@ -498,8 +498,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)


@@ -244,12 +244,22 @@ class TensorDiagnostic(object):
             if stats_type == "eigs":
                 try:
-                    eigs, _ = torch.symeig(stats)
+                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                        eigs, _ = torch.linalg.eigh(stats)
+                    else:
+                        eigs, _ = torch.symeig(stats)
                     stats = eigs.abs().sqrt()
                 except:  # noqa
-                    print("Error getting eigenvalues, trying another method.")
-                    eigs, _ = torch.eig(stats)
-                    stats = eigs.norm(dim=1).sqrt()
+                    print(
+                        "Error getting eigenvalues, trying another method."
+                    )
+                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                        eigs, _ = torch.linalg.eig(stats)
+                        eigs = eigs.abs()
+                    else:
+                        eigs, _ = torch.eig(stats)
+                        eigs = eigs.norm(dim=1)
+                    stats = eigs.sqrt()

             # sqrt so it reflects data magnitude, like stddev- not variance
             if stats_type in ["rms", "stddev"]:
@@ -569,11 +579,10 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
+                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                            class_name=get_class_name(_module))

         def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
@@ -587,11 +596,9 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in (torch.float32, torch.float16, torch.float64):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                            o, class_name=get_class_name(_module)
-                        )
+                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                            class_name=get_class_name(_module))

     module.register_forward_hook(forward_hook)
     module.register_backward_hook(backward_hook)
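The guarded calls in the first hunk above select an eigensolver by PyTorch version: torch.symeig and torch.eig are deprecated (and removed in recent releases) in favor of torch.linalg.eigh / torch.linalg.eig. A standalone sketch of the same pattern:

    import torch

    stats = torch.randn(4, 4)
    stats = stats @ stats.t()  # symmetric, covariance-like matrix
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        eigs, _ = torch.linalg.eigh(stats)                # PyTorch >= 1.8
    else:
        eigs, _ = torch.symeig(stats, eigenvectors=True)  # older releases
    print(eigs.abs().sqrt())  # magnitudes on a stddev-like scale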


@@ -498,7 +498,7 @@ def store_transcripts(
     Returns:
       Return None.
     """
-    with open(filename, "w") as f:
+    with open(filename, "w", encoding="utf8") as f:
         for cut_id, ref, hyp in texts:
             if char_level:
                 ref = list("".join(ref))
@@ -523,7 +523,7 @@ def store_transcripts_and_timestamps(
     Returns:
       Return None.
     """
-    with open(filename, "w") as f:
+    with open(filename, "w", encoding="utf8") as f:
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)
@@ -1447,7 +1447,7 @@ def get_parameter_groups_with_lrs(
     This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
     named-parameters; we can, if needed, create a version without the names).

-    It provides a way to specifiy learning-rate scales inside the module, so that if
+    It provides a way to specify learning-rate scales inside the module, so that if
     any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
     scale the LR of any parameters inside that module or its submodules. Note: you
     can set module parameters outside the __init__ function, e.g.:
@@ -1607,10 +1607,10 @@ def tokenize_by_bpe_model(
     chars = pattern.split(txt.upper())
     mix_chars = [w for w in chars if len(w.strip()) > 0]
     for ch_or_w in mix_chars:
-        # ch_or_w is a single CJK charater(i.e., "你"), do nothing.
+        # ch_or_w is a single CJK character(i.e., "你"), do nothing.
         if pattern.fullmatch(ch_or_w) is not None:
             tokens.append(ch_or_w)
-        # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
+        # ch_or_w contains non-CJK characters(i.e., " IT'S OKAY "),
         # encode ch_or_w using bpe_model.
         else:
             for p in sp.encode_as_pieces(ch_or_w):
@@ -1624,7 +1624,7 @@ def tokenize_by_CJK_char(line: str) -> str:
     """
     Tokenize a line of text with CJK char.

-    Note: All return charaters will be upper case.
+    Note: All return characters will be upper case.

     Example:
       input = "你好世界是 hello world 的中文"
@@ -1917,7 +1917,7 @@ def parse_bpe_timestamps_and_texts(
       A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
       containing multiple FSAs, which is expected to be the result
       of k2.shortest_path (otherwise the returned values won't
-      be meaningful). Its attribtutes `labels` and `aux_labels`
+      be meaningful). Its attributes `labels` and `aux_labels`
       are both BPE tokens.
     sp:
       The BPE model.
@@ -1977,7 +1977,7 @@ def parse_timestamps_and_texts(
       A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
       containing multiple FSAs, which is expected to be the result
       of k2.shortest_path (otherwise the returned values won't
-      be meaningful). Attribtute `labels` is the prediction unit,
+      be meaningful). Attribute `labels` is the prediction unit,
       e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
     word_table:
       The word symbol table.
@@ -2045,7 +2045,7 @@ def parse_fsa_timestamps_and_texts(
 ) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
     """Parse timestamps (in seconds) and texts for given decoded fsa paths.
     Currently it supports two cases:
-    (1) ctc-decoding, the attribtutes `labels` and `aux_labels`
+    (1) ctc-decoding, the attributes `labels` and `aux_labels`
     are both BPE tokens. In this case, sp should be provided.
     (2) HLG-based 1best, the attribtute `labels` is the prediction unit,
     e.g., phone or BPE tokens; attribute `aux_labels` is the word index.