Merge branch 'k2-fsa:master' into dev/cv-zipformer

This commit is contained in:
zr_jin 2024-03-20 17:00:36 +08:00 committed by GitHub
commit c2740030df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 10731 additions and 197 deletions

View File

@ -36,7 +36,9 @@ RUN pip install --no-cache-dir \
\
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
cython \
dill \
espnet_tts_frontend \
graphviz \
kaldi-decoder \
kaldi_native_io \
@ -45,13 +47,15 @@ RUN pip install --no-cache-dir \
kaldilm \
matplotlib \
multi_quantization \
numba \
numpy \
onnx \
onnxmltools \
onnxruntime \
piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
pypinyin==0.50.0 \
pytest \
sentencepiece>=0.1.96 \
pypinyin==0.50.0 \
six \
tensorboard \
typeguard

View File

@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version):
def get_matrix():
k2_version = "1.24.4.dev20240223"
kaldifeat_version = "1.25.4.dev20240223"
version = "20240223"
version = "20240318"
python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
torch_version += ["2.2.0", "2.2.1"]

View File

@ -64,6 +64,46 @@ function run_diagnostics() {
--print-diagnostics 1
}
function test_streaming_zipformer_ctc_hlg() {
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
rm $repo/exp-ctc-rnnt-small/*.onnx
ls -lh $repo/exp-ctc-rnnt-small
# export models to onnx
./zipformer/export-onnx-streaming-ctc.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 3 \
--exp-dir $repo/exp-ctc-rnnt-small \
--causal 1 \
--use-ctc 1 \
--chunk-size 16 \
--left-context-frames 128 \
\
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192
ls -lh $repo/exp-ctc-rnnt-small
for wav in 0.wav 1.wav 8k.wav; do
python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
--nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
--words $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.fst \
$repo/test_wavs/$wav
done
rm -rf $repo
}
function test_pruned_transducer_stateless_2022_03_12() {
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {
prepare_data
run_diagnostics
test_streaming_zipformer_ctc_hlg
test_pruned_transducer_stateless_2022_03_12
test_pruned_transducer_stateless2_2022_04_29
test_pruned_transducer_stateless3_2022_04_29

View File

@ -419,7 +419,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -432,7 +432,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
f,
f"{test_set_name}-{key}",
results_char,
enable_log=enable_log,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -431,7 +431,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -444,7 +444,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
f,
f"{test_set_name}-{key}",
results_char,
enable_log=enable_log,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -390,7 +390,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -402,7 +402,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -526,7 +526,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -538,7 +538,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -444,7 +444,7 @@ def save_results(
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
store_transcripts(filename=recog_path, texts=results_char)
store_transcripts(filename=recog_path, texts=results_char, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -452,7 +452,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -581,7 +581,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -594,7 +594,11 @@ def save_results(
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -492,7 +492,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -500,7 +500,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -278,7 +278,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -289,7 +289,13 @@ def save_results(
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
wer = write_error_stats(
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))

View File

@ -327,7 +327,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
@ -338,7 +338,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -372,7 +372,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -384,7 +384,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -376,7 +376,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -388,7 +388,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -358,7 +358,7 @@ def save_results(
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -373,7 +373,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
f,
f"{test_set_name}-{key}",
results_char,
enable_log=enable_log,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -793,7 +793,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -560,7 +560,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@ -570,7 +570,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer

View File

@ -1036,7 +1036,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1120,7 +1120,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1058,7 +1058,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -36,6 +36,7 @@ The following table lists the differences among them.
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
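For intuition, here is a minimal sketch (not part of this commit; the class and parameter names are illustrative) of the "stateless" prediction network idea referenced above: an embedding followed by a causal 1-D convolution over a fixed number of previous tokens, instead of an RNN.

import torch
import torch.nn as nn


class StatelessDecoderSketch(nn.Module):
    """Embedding + causal Conv1d over the last `context_size` tokens (illustrative only)."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # A kernel of size `context_size` sees only a fixed, short left context,
        # which is what makes the prediction network "stateless".
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token ids -> output: (N, U, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embed_dim, U)
        emb = nn.functional.pad(emb, (self.conv.kernel_size[0] - 1, 0))  # left padding only
        return torch.relu(self.conv(emb)).permute(0, 2, 1)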

View File

@ -479,18 +479,14 @@ class LibriSpeechAsrDataModule:
@lru_cache()
def gigaspeech_subset_small_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech subset-S cuts")
return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
@lru_cache()
def gigaspeech_dev_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
@lru_cache()
def gigaspeech_test_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")

View File

@ -32,7 +32,7 @@ This script exports a CTC model from PyTorch to ONNX.
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64 \
--left-context-frames 128 \
--use-ctc 1
The --chunk-size in training is "16,32,64,-1", so we select one of them
@ -41,7 +41,7 @@ whose value is "64,128,256,-1".
It will generate the following file inside $repo/exp:
- ctc-epoch-99-avg-1-chunk-16-left-64.onnx
- ctc-epoch-99-avg-1-chunk-16-left-128.onnx
See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models.
"""

View File

@ -48,7 +48,7 @@ popd
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
--left-context-frames 128
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
@ -56,9 +56,9 @@ whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1-chunk-16-left-64.onnx
- decoder-epoch-99-avg-1-chunk-16-left-64.onnx
- joiner-epoch-99-avg-1-chunk-16-left-64.onnx
- encoder-epoch-99-avg-1-chunk-16-left-128.onnx
- decoder-epoch-99-avg-1-chunk-16-left-128.onnx
- joiner-epoch-99-avg-1-chunk-16-left-128.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
@ -333,6 +333,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
feature_dim: int = 80,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
@ -343,7 +344,7 @@ def export_encoder_model_onnx(
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
x = torch.rand(1, T, feature_dim, dtype=torch.float32)
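# Using feature_dim (instead of a hard-coded 80) keeps the dummy input for the
# ONNX export consistent with models trained on other feature dimensions.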
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
@ -724,6 +725,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
feature_dim=params.feature_dim,
)
logging.info(f"Exported encoder to {encoder_filename}")

View File

@ -0,0 +1,439 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script loads ONNX models exported by ./export-onnx-streaming-ctc.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp-ctc-rnnt-small/*.pt"
git lfs pull --include "data/lang_bpe_500/words.txt"
git lfs pull --include "data/lang_bpe_500/HLG.fst"
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming-ctc.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 3 \
--exp-dir $repo/exp-ctc-rnnt-small \
--causal 1 \
--use-ctc 1 \
--chunk-size 16 \
--left-context-frames 128 \
\
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192
It will generate the following 2 files inside $repo/exp-ctc-rnnt-small:
- ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx
- ctc-epoch-30-avg-3-chunk-16-left-128.onnx
You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
3. Run this file with the exported ONNX models
python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
--nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
--words $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.fst \
$repo/test_wavs/0.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
Note: HLG.fst is generated directly from ../local/prepare_lang_fst.py
"""
import argparse
import logging
from typing import Dict, List, Tuple
import k2
import kaldifst
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="Path to the onnx model. ",
)
parser.add_argument(
"--words",
type=str,
required=True,
help="""Path to words.txt.""",
)
parser.add_argument(
"--HLG",
type=str,
required=True,
help="""Path to HLG.fst.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
class OnnxModel:
def __init__(
self,
model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_model(model_filename)
def init_model(self, model_filename: str):
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_states()
def init_states(self, batch_size: int = 1):
meta = self.model.get_modelmeta().custom_metadata_map
logging.info(f"meta={meta}")
model_type = meta["model_type"]
assert model_type == "zipformer2", model_type
decode_chunk_len = int(meta["decode_chunk_len"])
T = int(meta["T"])
num_encoder_layers = meta["num_encoder_layers"]
encoder_dims = meta["encoder_dims"]
cnn_module_kernels = meta["cnn_module_kernels"]
left_context_len = meta["left_context_len"]
query_head_dims = meta["query_head_dims"]
value_head_dims = meta["value_head_dims"]
num_heads = meta["num_heads"]
def to_int_list(s):
return list(map(int, s.split(",")))
num_encoder_layers = to_int_list(num_encoder_layers)
encoder_dims = to_int_list(encoder_dims)
cnn_module_kernels = to_int_list(cnn_module_kernels)
left_context_len = to_int_list(left_context_len)
query_head_dims = to_int_list(query_head_dims)
value_head_dims = to_int_list(value_head_dims)
num_heads = to_int_list(num_heads)
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
logging.info(f"left_context_len: {left_context_len}")
logging.info(f"query_head_dims: {query_head_dims}")
logging.info(f"value_head_dims: {value_head_dims}")
logging.info(f"num_heads: {num_heads}")
num_encoders = len(num_encoder_layers)
self.states = []
for i in range(num_encoders):
num_layers = num_encoder_layers[i]
key_dim = query_head_dims[i] * num_heads[i]
embed_dim = encoder_dims[i]
nonlin_attn_head_dim = 3 * embed_dim // 4
value_dim = value_head_dims[i] * num_heads[i]
conv_left_pad = cnn_module_kernels[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(
left_context_len[i], batch_size, key_dim
).numpy()
cached_nonlin_attn = torch.zeros(
1, batch_size, left_context_len[i], nonlin_attn_head_dim
).numpy()
cached_val1 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_val2 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
self.states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
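# Cached left padding for the encoder_embed ConvNeXt front end,
# of shape (batch_size, num_channels, left_pad, num_freqs).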
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
self.states.append(processed_lens)
self.num_encoders = num_encoders
self.segment = T
self.offset = decode_chunk_len
def _build_model_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
model_input = {"x": x.numpy()}
model_output = ["log_probs"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
model_input[name] = tensors[0]
model_output.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
model_input[name] = tensors[1]
model_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
model_input[name] = tensors[2]
model_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
model_input[name] = tensors[3]
model_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
model_input[name] = tensors[4]
model_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
model_input[name] = tensors[5]
model_output.append(f"new_{name}")
for i in range(len(self.states[:-2]) // 6):
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
name = "embed_states"
embed_states = self.states[-2]
model_input[name] = embed_states
model_output.append(f"new_{name}")
# (batch_size,)
name = "processed_lens"
processed_lens = self.states[-1]
model_input[name] = processed_lens
model_output.append(f"new_{name}")
return model_input, model_output
def _update_states(self, states: List[np.ndarray]):
self.states = states
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor containing log_probs. Its shape is (N, T', vocab_size)
where T' is usually equal to ((T-7)//2 - 3)//2
"""
model_input, model_output_names = self._build_model_input_output(x)
out = self.model.run(model_output_names, model_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
logging.info(f"Resample {sample_rate} to {expected_sample_rate}")
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
word_table = k2.SymbolTable.from_file(args.words)
model = OnnxModel(model_filename=args.nn_model)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
logging.info(f"Loading HLG from {args.HLG}")
HLG = kaldifst.StdVectorFst.read(args.HLG)
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HLG, decoder_opts)
decoder.init_decoding()
chunk = int(1 * sample_rate) # 1 second
start = 0
n = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
# simulate streaming
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
log_probs = model(frames)
log_probs = log_probs.squeeze(0).cpu().numpy()
decodable = DecodableCtc(log_probs, offset=n)
n += log_probs.shape[0]
num_processed_frames += offset
decoder.advance_decoding(decodable)
if not decoder.reached_final():
logging.info(f"Failed to decode {args.sound_file}")
return
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
logging.info(f"Failed to get linear symbol sequence for {args.sound_file}")
return
hyps = " ".join([word_table[i] for i in osymbols_out]).lower()
logging.info(f"\n{args.sound_file}\n{hyps}")
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../zipformer/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,543 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the librispeech dataset; if you are using a different
dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer_lora/export.py \
--exp-dir ./zipformer_lora/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer_lora/export.py \
--exp-dir ./zipformer_lora/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer_lora/export.py \
--exp-dir ./zipformer_lora/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer_lora/export.py \
--exp-dir ./zipformer_lora/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer_lora/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer_lora/decode.py \
--exp-dir ./zipformer_lora/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer_lora/decode.py` and `zipformer_lora/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer_lora/decode.py \
--exp-dir ./zipformer_lora/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
# chunk-wise streaming decoding
./zipformer_lora/streaming_decode.py \
--exp-dir ./zipformer_lora/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_lora/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsamples the features: T -> (T - 7) // 2 frames
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
# merge the LoRA weights
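# Assumption (not stated in this file): model.eval() lets the LoRA-aware modules
# fold their low-rank deltas into the base weights, so the state_dict below can be
# copied into a plain (use_lora=False) model by matching parameter names.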
model.eval()
params.use_lora = False
base_model = get_model(params)
new_state_dict = {}
state_dict = model.state_dict()
param_names = base_model.state_dict().keys()
for k in param_names:
assert k in state_dict.keys()
new_state_dict[k] = state_dict[k]
base_model.load_state_dict(new_state_dict, strict=True)
model = base_model
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../zipformer/joiner.py

View File

@ -0,0 +1 @@
../zipformer/model.py

View File

@ -0,0 +1 @@
../zipformer/optim.py

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../zipformer/subsampling.py

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -824,7 +824,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -115,9 +115,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
cat ./data/lang_bpe_500/transcript_words.txt \
>> $lang_dir/text_words_segmentation
cat ./data/lang_char/text \
>> $lang_dir/text
fi
cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/streaming_decode.py

View File

@ -0,0 +1,869 @@
#!/usr/bin/env python3
# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import AsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single zipformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
state_list[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Returns:
state_list: A list of lists. Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
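# Illustration (not part of the original source): for per-stream states s0 and s1
# returned by get_init_states(model), stack_states([s0, s1]) packs them along the
# batch dimension and unstack_states() splits them back, so
# unstack_states(stack_states([s0, s1])) recovers [s0, s1] element-wise.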
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A list of DecodeStream, each corresponding to an utterance.
Returns:
Return a list containing the indexes of the DecodeStreams that are finished.
"""
device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsamples features: T -> (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
tail_length = chunk_size * 2 + 7 + 2 * 3
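# For example, with chunk_size = 16 (an illustrative value),
# tail_length = 16 * 2 + 7 + 2 * 3 = 45 feature frames.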
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
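# Note: the returned indexes refer to positions in `decode_streams`; callers
# (e.g. decode_dataset below) delete finished streams in reverse index order
# so that the remaining indexes stay valid.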
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse CutSet containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
It is used only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or "beam_7" if a beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result
(an illustrative example follows this function).
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contains the decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# Each utterance has its own DecodeStream.
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model uses normalized samples
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
#   go out of the range [-1, 1], hence the tolerance of 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# Decode the final chunks of the remaining sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
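# For reference, the dict returned above has the following shape
# (values are illustrative only):
#
#   {
#       "greedy_search": [
#           ("cut-0001", ["THE", "REFERENCE", "WORDS"], ["THE", "DECODED", "WORDS"]),
#           ...
#       ]
#   }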
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
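# The wer-summary file written above is a small tab-separated table, e.g.
# (illustrative values):
#
#   settings        WER
#   greedy_search   3.21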
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
multi_dataset = MultiDataset(args)
def remove_short_utt(c: Cut):
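# T is a rough estimate of the number of encoder frames left after the
# front-end subsampling; cuts that would produce no frames are excluded
# so that decoding does not fail on extremely short utterances.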
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}"
)
return T > 0
test_sets_cuts = multi_dataset.test_cuts()
test_sets = test_sets_cuts.keys()
test_cuts = [test_sets_cuts[k] for k in test_sets]
for test_set, test_cut in zip(test_sets, test_cuts):
logging.info(f"Decoding {test_set}")
test_cut = test_cut.filter(remove_short_utt)
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -10,7 +10,7 @@ The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk
This recipe provides a VITS model trained on the VCTK dataset.
Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html).
@ -21,7 +21,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt \
--max-duration 350

View File

@ -1,104 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict
from lhotse import load_manifest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
help="Path to the manifest file",
)
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()
cut_set = load_manifest(manifest_file)
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
for t in cut.tokens:
all_tokens.add(t)
all_tokens = extra_tokens + list(all_tokens)
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/local/prepare_token_file.py

View File

@ -24,9 +24,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
import logging
from pathlib import Path
import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
from piper_phonemize import phonemize_espeak
from tqdm.auto import tqdm
@ -37,17 +37,20 @@ def prepare_tokens_vctk():
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
g2p = g2p_en.G2p()
new_cuts = []
for cut in tqdm(cut_set):
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
cut.tokens = g2p(text)
tokens_list = phonemize_espeak(text, "en-us")
tokens = []
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)

View File

@ -78,6 +78,13 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for VCTK"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize:
# refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend:
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
./local/prepare_tokens_vctk.py
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
@ -111,14 +118,15 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
# We assume you have installed g2p_en and espnet_tts_frontend.
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
# - piper_phonemize:
# refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - espnet_tts_frontend:
# `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
--tokens data/tokens.txt
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -97,7 +98,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
meta.value = str(value)
onnx.save(model, filename)
@ -160,6 +161,7 @@ def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
n_speakers: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@ -212,10 +214,15 @@ def export_model_onnx(
)
meta_data = {
"model_type": "VITS",
"model_type": "vits",
"version": "1",
"model_author": "k2-fsa",
"comment": "VITS generator",
"comment": "icefall", # must be icefall for models from icefall
"language": "English",
"voice": "en-us", # Choose your language appropriately
"has_espeak": 1,
"n_speakers": n_speakers,
"sample_rate": 22050, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
@ -231,8 +238,7 @@ def main():
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
with open(args.speakers) as f:
@ -265,6 +271,7 @@ def main():
model,
model_filename,
params.vocab_size,
params.num_spks,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")

View File

@ -135,14 +135,16 @@ def infer_dataset(
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
.int()
@ -214,8 +216,7 @@ def main():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
# we need cut ids to display recognition results.

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -122,7 +123,9 @@ def main():
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
speaker = torch.tensor([1], dtype=torch.int64) # (1, )

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -342,14 +343,16 @@ def prepare_input(
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
)
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
@ -812,8 +815,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
vctk = VctkTtsDataModule(args)

View File

@ -1,6 +1,7 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

View File

@ -803,7 +803,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1187,7 +1187,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1,6 +1,7 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey
# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
# Zengwei Yao
# Mingshuang Luo)
# Mingshuang Luo,
# Zengrui Jin,)
#
# See ../LICENSE for clarification regarding multiple authors
#
@ -16,9 +17,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
@ -653,7 +655,13 @@ def attach_diagnostics(
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
try:
parameter.register_hook(param_backward_hook)
except:
logging.warning(
f"Warning: could not register backward hook for parameter {name}, "
f"it might not be differentiable."
)
return ans

View File

@ -1,4 +1,6 @@
# Copyright 2021-2022 Xiaomi Corporation (authors: Zengwei Yao, Daniel Povey)
# Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
# Daniel Povey,
# Zengrui Jin,)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -77,7 +79,13 @@ def register_inf_check_hooks(model: nn.Module) -> None:
if not torch.isfinite(grad.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.param_grad is not finite")
try:
parameter.register_hook(param_backward_hook)
except:
logging.warning(
f"Warning: could not register backward hook for parameter {name}, "
f"it might not be differentiable."
)
def _test_inf_check_hooks():

View File

@ -1080,10 +1080,12 @@ def write_surt_error_stats(
print(
f"{cut_id}:\t"
+ " ".join(
(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
)
for ref_word, hyp_word in ali
)
),