Commit more scripts for wenetspeech kws recipe

pkufool 2024-02-02 12:18:06 +08:00
parent 4b3356307a
commit 8b65f4138b
10 changed files with 2353 additions and 147 deletions

View File

@ -158,6 +158,13 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
add_training_arguments(parser)
add_model_arguments(parser)
add_finetune_arguments(parser)

View File

@ -407,10 +407,3 @@ class WenetSpeechAsrDataModule:
def test_meeting_cuts(self) -> List[CutSet]:
logging.info("About to get TEST_MEETING cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
@lru_cache()
def test_open_commands_cuts(self) -> CutSet:
logging.info("About to get open commands cuts")
return load_manifest_lazy(
self.args.manifest_dir / "open-commands-cn_cuts_test.jsonl.gz"
)

egs/wenetspeech/KWS/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -1,4 +1,5 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -409,8 +410,50 @@ class WenetSpeechAsrDataModule:
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
@lru_cache()
def test_open_commands_cuts(self) -> CutSet:
logging.info("About to get open commands cuts")
def cn_speech_commands_small_cuts(self) -> CutSet:
logging.info("About to get cn speech commands small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "open-commands-cn_cuts_test.jsonl.gz"
self.args.manifest_dir / "cn_speech_commands_cuts_small.jsonl.gz"
)
@lru_cache()
def cn_speech_commands_large_cuts(self) -> CutSet:
logging.info("About to get cn speech commands large cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cn_speech_commands_cuts_large.jsonl.gz"
)
@lru_cache()
def nihaowenwen_dev_cuts(self) -> CutSet:
logging.info("About to get nihaowenwen dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "nihaowenwen_cuts_dev.jsonl.gz"
)
@lru_cache()
def nihaowenwen_test_cuts(self) -> CutSet:
logging.info("About to get nihaowenwen test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "nihaowenwen_cuts_test.jsonl.gz"
)
@lru_cache()
def nihaowenwen_train_cuts(self) -> CutSet:
logging.info("About to get nihaowenwen train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "nihaowenwen_cuts_train.jsonl.gz"
)
@lru_cache()
def xiaoyun_clean_cuts(self) -> CutSet:
logging.info("About to get xiaoyun clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "xiaoyun_cuts_clean.jsonl.gz"
)
@lru_cache()
def xiaoyun_noisy_cuts(self) -> CutSet:
logging.info("About to get xiaoyun noisy cuts")
return load_manifest_lazy(
self.args.manifest_dir / "xiaoyun_cuts_noisy.jsonl.gz"
)

View File

@ -178,16 +178,47 @@ def get_parser():
)
parser.add_argument(
"--keyword-file",
"--keywords-file",
type=str,
help="File contains keywords.",
)
parser.add_argument(
"--keyword-score",
"--test-set",
type=str,
default="small",
help="small or large",
)
parser.add_argument(
"--keywords-score",
type=float,
default=0.75,
help="The threshold (probability) to boost the keyword.",
default=1.5,
help="""
The default boosting score (token level) for keywords. It boosts the
paths that match keywords so that they survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.35,
help="The default threshold (probability) to trigger the keyword.",
)
parser.add_argument(
"--keywords-version",
type=str,
default="",
help="The keywords configuration version, just to save results to different files.",
)
parser.add_argument(
"--num-tailing-blanks",
type=int,
default=1,
help="The number of tailing blanks should have after hitting one keyword.",
)
add_model_arguments(parser)
@ -261,7 +292,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_graph=kws_graph,
keywords_graph=kws_graph,
beam=params.beam_size,
num_tailing_blanks=8,
)
@ -288,6 +319,7 @@ def decode_dataset(
lexicon: Lexicon,
kws_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -342,32 +374,62 @@ def decode_dataset(
hyp_words = [x[0] for x in hyp_words]
this_batch.append((cut_id, ref_words, list("".join(hyp_words))))
hyp_set = set(hyp_words)
hyp_str = " | ".join(hyp_words)
if len(hyp_words) > 1:
logging.warning(
f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
f"please check the transcript to see if it really has more "
f"than one keywords, if so consider splitting this audio and"
f"keep only one keyword for each audio."
)
hyp_str = " | ".join(
hyp_words
) # The triggered keywords for this utterance.
TP = False
FP = False
for x in hyp_set:
assert x in keywords, x
if x in ref_text and x in keywords:
metric["all"].TP += 1
assert x in keywords, x # can only trigger keywords
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
TP = True
metric[x].TP += 1
metric[x].TP_list.append(f"({ref_text} -> {x})")
if x not in ref_text and x in keywords:
metric["all"].FP += 1
if (test_only_keywords and x != ref_text) or (
not test_only_keywords and x not in ref_text
):
FP = True
metric[x].FP += 1
metric[x].FP_list.append(f"({ref_text} -> {x}/{cut_id})")
metric[x].FP_list.append(f"({ref_text} -> {x})")
if TP:
metric["all"].TP += 1
if FP:
metric["all"].FP += 1
TN = True # If all keywords are true negatives, then the summary is a true negative.
FN = False
for x in keywords:
if x not in ref_text and x not in hyp_set:
metric["all"].TN += 1
metric[x].TN += 1
continue
if x in ref_text:
TN = False
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
fn = True
for y in hyp_set:
if y in ref_text:
if (test_only_keywords and y == ref_text) or (
not test_only_keywords and y in ref_text
):
fn = False
break
if fn and ref_text.endswith(x):
metric["all"].FN += 1
if fn:
FN = True
metric[x].FN += 1
metric[x].FN_list.append(f"({ref_text} -> {hyp_str}/{cut_id})")
metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
if TN:
metric["all"].TN += 1
if FN:
metric["all"].FN += 1
results.extend(this_batch)
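
Note: the TP/FP/FN/TN bookkeeping above can be hard to follow inline, so here is a minimal, stand-alone sketch of the per-keyword decision rule it implements. The helper names (`keyword_matches`, `classify_keyword`) are illustrative and not part of the recipe.

def keyword_matches(keyword: str, ref_text: str, test_only_keywords: bool) -> bool:
    """True if `keyword` counts as present in the reference transcript."""
    if test_only_keywords:
        # Keyword-only test sets: the whole utterance must be the keyword.
        return keyword == ref_text
    # General test sets (e.g. test_net): a substring match is enough.
    return keyword in ref_text


def classify_keyword(
    keyword: str, ref_text: str, triggered: set, test_only_keywords: bool
) -> str:
    """Classify one keyword for one utterance as TP / FP / TN / FN / ignored."""
    hit = keyword in triggered
    present = keyword_matches(keyword, ref_text, test_only_keywords)
    if hit:
        return "TP" if present else "FP"
    if not present:
        return "TN"
    # The keyword is in the reference but was not triggered; it only counts
    # as an FN when no other triggered keyword matches the reference either.
    any_other_hit = any(
        keyword_matches(y, ref_text, test_only_keywords) for y in triggered
    )
    return "ignored" if any_other_hit else "FN"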
@ -399,16 +461,17 @@ def save_results(
metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
print_s = ""
with open(metric_filename, "w") as of:
width = 10
for key, item in sorted(
metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
):
acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
precision = (item.TP + 1) / (item.TP + item.FP + 1)
recall = (item.TP + 1) / (item.TP + item.FN + 1)
fpr = (item.FP + 1) / (item.FP + item.TN + 1)
precision = (
0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
)
recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
s = f"{key}:\n"
s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
@ -417,12 +480,14 @@ def save_results(
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
if key != "all":
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
of.write(s + "\n")
if key == "all":
logging.info(s)
of.write(f"\n\n{params.keywords_config}")
logging.info("Wrote metric stats to {}".format(metric_filename))
@ -439,6 +504,7 @@ def main():
params.res_dir = params.exp_dir / "kws"
params.suffix = params.test_set
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
@ -454,9 +520,12 @@ def main():
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
params.suffix += f"-keyword-score-{params.keyword_score}"
params.suffix += f"-score-{params.keywords_score}"
params.suffix += f"-threshold-{params.keywords_threshold}"
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
params.suffix += f"-version-{params.keywords_version}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -473,18 +542,30 @@ def main():
logging.info(params)
keywords = []
keywords_id = []
with open(params.keyword_file, "r") as f:
phrases = []
token_ids = []
keywords_scores = []
keywords_thresholds = []
keywords_config = []
with open(params.keywords_file, "r") as f:
for line in f.readlines():
keywords_config.append(line)
score = 0
kws = line.strip().upper().split()
if kws[-1][0] == ":":
score = float(kws[-1][1:])
kws = kws[0:-1]
threshold = 0
keyword = []
words = line.strip().upper().split()
for word in words:
word = word.strip()
if word[0] == ":":
score = float(word[1:])
continue
if word[0] == "#":
threshold = float(word[1:])
continue
keyword.append(word)
keyword = "".join(keyword)
tmp_ids = []
kws = "".join(kws)
kws_py = text_to_pinyin(kws, mode=params.pinyin_type)
kws_py = text_to_pinyin(keyword, mode=params.pinyin_type)
for k in kws_py:
if k in lexicon.token_table:
tmp_ids.append(lexicon.token_table[k])
@ -493,11 +574,23 @@ def main():
tmp_ids = []
break
if tmp_ids:
logging.info(f"Adding keyword : {kws}")
keywords.append(kws)
keywords_id.append((tmp_ids, score, kws))
kws_graph = ContextGraph(context_score=params.keyword_score)
kws_graph.build(keywords_id)
logging.info(f"Adding keyword : {keyword}")
phrases.append(keyword)
token_ids.append(tmp_ids)
keywords_scores.append(score)
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)
kws_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
kws_graph.build(
token_ids=token_ids,
phrases=phrases,
scores=keywords_scores,
ac_thresholds=keywords_thresholds,
)
keywords = set(phrases)
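
The loop above implies the following keywords-file format: one keyword phrase per line, words separated by spaces, with an optional `:score` token giving a per-keyword boosting score and an optional `#threshold` token giving a per-keyword trigger threshold (otherwise the global `--keywords-score` / `--keywords-threshold` presumably apply). A minimal parsing sketch with made-up file contents (the phrases and numbers are purely illustrative):

def parse_keyword_line(line: str):
    score = 0       # stays 0 when the line gives no per-keyword score
    threshold = 0   # stays 0 when the line gives no per-keyword threshold
    words = []
    for word in line.strip().upper().split():
        if word.startswith(":"):
            score = float(word[1:])
        elif word.startswith("#"):
            threshold = float(word[1:])
        else:
            words.append(word)
    return "".join(words), score, threshold


print(parse_keyword_line("你好 问问 :2.0 #0.4"))  # ('你好问问', 2.0, 0.4)
print(parse_keyword_line("小 云 小 云"))          # ('小云小云', 0, 0)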
logging.info("About to create model")
model = get_model(params)
@ -597,21 +690,51 @@ def main():
)
return T > 0
def select_keywords(c: Cut):
text = c.supervisions[0].text.strip()
return text in keywords
commands_cuts = wenetspeech.test_open_commands_cuts()
commands_cuts = commands_cuts.filter(select_keywords)
commands_cuts = commands_cuts.filter(remove_short_utt)
commands_dl = wenetspeech.test_dataloaders(commands_cuts)
test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_sets = ["COMMANDS"] # , "TEST_NET"]
test_dls = [commands_dl] # , test_net_dl]
cn_commands_small_cuts = wenetspeech.cn_speech_commands_small_cuts()
cn_commands_small_cuts = cn_commands_small_cuts.filter(remove_short_utt)
cn_commands_small_dl = wenetspeech.test_dataloaders(cn_commands_small_cuts)
cn_commands_large_cuts = wenetspeech.cn_speech_commands_large_cuts()
cn_commands_large_cuts = cn_commands_large_cuts.filter(remove_short_utt)
cn_commands_large_dl = wenetspeech.test_dataloaders(cn_commands_large_cuts)
nihaowenwen_test_cuts = wenetspeech.nihaowenwen_test_cuts()
nihaowenwen_test_cuts = nihaowenwen_test_cuts.filter(remove_short_utt)
nihaowenwen_test_dl = wenetspeech.test_dataloaders(nihaowenwen_test_cuts)
xiaoyun_clean_cuts = wenetspeech.xiaoyun_clean_cuts()
xiaoyun_clean_cuts = xiaoyun_clean_cuts.filter(remove_short_utt)
xiaoyun_clean_dl = wenetspeech.test_dataloaders(xiaoyun_clean_cuts)
xiaoyun_noisy_cuts = wenetspeech.xiaoyun_noisy_cuts()
xiaoyun_noisy_cuts = xiaoyun_noisy_cuts.filter(remove_short_utt)
xiaoyun_noisy_dl = wenetspeech.test_dataloaders(xiaoyun_noisy_cuts)
test_sets = []
test_dls = []
if params.test_set == "large":
test_sets.append("cn_commands_large")
test_dls.append(cn_commands_large_dl)
else:
assert params.test_set == "small", params.test_set
test_sets += [
"cn_commands_small",
"nihaowenwen",
"xiaoyun_clean",
"xiaoyun_noisy",
"test_net",
]
test_dls += [
cn_commands_small_dl,
nihaowenwen_test_dl,
xiaoyun_clean_dl,
xiaoyun_noisy_dl,
test_net_dl,
]
for test_set, test_dl in zip(test_sets, test_dls):
results, metric = decode_dataset(
@ -620,7 +743,8 @@ def main():
model=model,
lexicon=lexicon,
kws_graph=kws_graph,
keywords=set(keywords),
keywords=keywords,
test_only_keywords="test_net" not in test_set,
)
save_results(

View File

@ -0,0 +1,526 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the librispeech dataset; if you are using a different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer/decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
# chunk-wise streaming decoding
./zipformer/streaming_decode.py \
--exp-dir ./zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -3,6 +3,7 @@
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@ -23,29 +24,44 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
# For non-streaming model finetuning:
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 30 \
--num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
# For streaming model training:
./zipformer/train.py \
# For non-streaming model finetuning with mux (original dataset):
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 30 \
--num-epochs 10 \
--start-epoch 1 \
--use-mux 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 1000
# For streaming model finetuning:
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--max-duration 1000
It supports training with:
- transducer loss (default), with `--use-transducer True --use-ctc False`
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
# For streaming model finetuning with mux (original dataset):
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--use-mux 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--max-duration 1000
"""
@ -55,7 +71,7 @@ import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import optim
@ -63,12 +79,10 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut
from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -76,7 +90,7 @@ from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
@ -109,9 +123,50 @@ from train import (
set_batch_count,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_finetune_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--use-mux",
type=str2bool,
default=False,
help="""
Whether to adapt. If true, we will mix 5% of the new data
with 95% of the original data to fine-tune.
""",
)
parser.add_argument(
"--init-modules",
type=str,
default=None,
help="""
Modules to be initialized. It matches all parameters starting with
a specific key. The keys are given, comma separated. If None,
all modules will be initialised. For example, if you only want to
initialise all parameters starting with "encoder", use "encoder";
if you want to initialise parameters starting with encoder or joiner,
use "encoder,joiner".
""",
)
parser.add_argument(
"--finetune-ckpt",
type=str,
default=None,
help="Fine-tuning from which checkpoint (a path to a .pt file)",
)
parser.add_argument(
"--continue-finetune",
type=str2bool,
default=False,
help="Continue finetuning or finetune from pre-trained model",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -148,10 +203,58 @@ def get_parser():
add_training_arguments(parser)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
def load_model_params(
ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
):
"""Load model params from checkpoint
Args:
ckpt (str): Path to the checkpoint
model (nn.Module): model to be loaded
"""
logging.info(f"Loading checkpoint from {ckpt}")
checkpoint = torch.load(ckpt, map_location="cpu")
# if module list is empty, load the whole model from ckpt
if not init_modules:
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
else:
src_state_dict = checkpoint["model"]
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
model.load_state_dict(dst_state_dict, strict=strict)
return None
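
A hedged usage sketch of `load_model_params` as defined above; the checkpoint path is a placeholder and the prefix list is just an example:

# Illustrative only: the path is a placeholder.
# Copy encoder and joiner weights from a pre-trained checkpoint and keep
# the remaining modules (e.g. the decoder) at their current initialisation.
load_model_params(
    ckpt="path/to/pretrained.pt",
    model=model,  # e.g. the AsrModel returned by get_model(params)
    init_modules=["encoder", "joiner"],
)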
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -160,7 +263,7 @@ def compute_loss(
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Compute loss given the model and its inputs.
Args:
params:
@ -191,10 +294,10 @@ def compute_loss(
texts = batch["supervisions"]["text"]
y = graph_compiler.texts_to_ids(texts, sep="/")
y = k2.RaggedTensor(y).to(device)
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
simple_loss, pruned_loss, ctc_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -203,21 +306,26 @@ def compute_loss(
lm_scale=params.lm_scale,
)
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = 0.0
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_transducer:
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
assert loss.requires_grad == is_training
@ -228,8 +336,11 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_transducer:
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info
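
To make the warm-up behaviour above concrete, here is a small sketch of how the two transducer loss scales evolve with `batch_idx_train`; the `warm_step` and `s` values below are arbitrary, not the recipe defaults:

# Illustrative only: shows the ramping of the two scales defined above.
def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    """s is params.simple_loss_scale."""
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    simple_scale = 1.0 - frac * (1.0 - s)  # ramps from 1.0 down to s
    pruned_scale = 0.1 + 0.9 * frac        # ramps from 0.1 up to 1.0
    return simple_scale, pruned_scale


for step in (0, 1000, 2000):
    print(step, loss_scales(step, warm_step=2000, s=0.5))
# 0    -> (1.0, 0.1)
# 1000 -> (0.75, 0.55)
# 2000 -> (0.5, 1.0)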
@ -317,8 +428,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
saved_bad_model = False
def save_bad_model(suffix: str = ""):
@ -336,10 +445,7 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
set_batch_count(model, get_adjusted_batch_count(params) + 100000)
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -359,6 +465,7 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
@ -387,7 +494,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -400,7 +506,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -532,14 +637,20 @@ def run(rank, world_size, args):
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
if params.continue_finetune:
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
else:
modules = params.init_modules.split(",") if params.init_modules else None
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model.to(device)
if world_size > 1:
@ -552,7 +663,7 @@ def run(rank, world_size, args):
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
@ -568,33 +679,31 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
def remove_short_utt(c: Cut):
if c.duration > 15:
return False
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 7) // 2 + 1) // 2
return T > 0
wenetspeech = WenetSpeechAsrDataModule(args)
train_cuts = wenetspeech.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 15.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
if params.use_mux:
train_cuts = CutSet.mux(
wenetspeech.train_cuts(),
wenetspeech.nihaowenwen_train_cuts(),
weights=[0.9, 0.1],
)
else:
train_cuts = wenetspeech.nihaowenwen_train_cuts()
def encode_text(c: Cut):
# Text normalize for each sample
@ -605,7 +714,7 @@ def run(rank, world_size, args):
c.supervisions[0].text = text
return c
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_cuts = train_cuts.filter(remove_short_utt)
train_cuts = train_cuts.map(encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
@ -619,19 +728,19 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = wenetspeech.valid_cuts()
valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_cuts = valid_cuts.map(encode_text)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# graph_compiler=graph_compiler,
# params=params,
# )
pass
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
@ -689,7 +798,6 @@ def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
@ -701,4 +809,6 @@ def main():
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

File diff suppressed because it is too large

View File

@ -1609,9 +1609,9 @@ def text_to_pinyin(
The input Chinese text.
mode:
The style of the output pinyin, should be:
full_with_tone : zhong1 guo2
full_with_tone : zhōng guó
full_no_tone : zhong guo
partial_with_tone : zh ong1 g uo2
partial_with_tone : zh ōng g uó
partial_no_tone : zh ong g uo
errors:
How to handle characters (e.g., Latin letters) that have no pinyin.
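
For reference, a hedged example of the four modes documented above (assuming `pypinyin` is installed and that `text_to_pinyin` returns a list of tokens, as its use in decode.py suggests); the expected outputs are copied from the docstring, not re-verified:

from icefall.utils import text_to_pinyin

for mode in ("full_with_tone", "full_no_tone", "partial_with_tone", "partial_no_tone"):
    print(mode, text_to_pinyin("中国", mode=mode))
# full_with_tone    -> ['zhōng', 'guó']
# full_no_tone      -> ['zhong', 'guo']
# partial_with_tone -> ['zh', 'ōng', 'g', 'uó']
# partial_no_tone   -> ['zh', 'ong', 'g', 'uo']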