init commit

This commit is contained in:
jinzr 2024-09-04 22:16:41 +08:00
parent cea0dbe7b1
commit dd82686a0f
68 changed files with 10210 additions and 0 deletions

View File

@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao,)
# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LibriTTS dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=True,
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""",
)
return parser.parse_args()
def compute_fbank_libritts(
dataset: Optional[str] = None,
sampling_rate: int = 24000,
perturb_speed: Optional[bool] = True,
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(32, os.cpu_count())
num_mel_bins = 80
if dataset is None:
dataset_parts = (
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
else:
dataset_parts = dataset.split(" ", -1)
prefix = "libritts"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if sampling_rate != 24000:
logging.info(f"Resampling audio to {sampling_rate}")
cut_set = cut_set.resample(sampling_rate)
if "train" in partition:
if perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
compute_fbank_libritts(
dataset=args.dataset,
sampling_rate=args.sampling_rate,
perturb_speed=args.perturb_speed,
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao,)
# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the VCTK dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/spectrogram.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
LilcomChunkyWriter,
Spectrogram,
SpectrogramConfig,
load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_spectrogram_libritts():
src_dir = Path("data/manifests")
output_dir = Path("data/spectrogram")
num_jobs = min(32, os.cpu_count())
sampling_rate = 24000
frame_length = 1024 / sampling_rate # (in second)
frame_shift = 256 / sampling_rate # (in second)
use_fft_mag = True
prefix = "libritts"
suffix = "jsonl.gz"
partition = "all"
recordings = load_manifest(
src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
).resample(sampling_rate=sampling_rate)
supervisions = load_manifest(
src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
)
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=frame_length,
frame_shift=frame_shift,
use_fft_mag=use_fft_mag,
)
extractor = Spectrogram(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=recordings, supervisions=supervisions
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_spectrogram_libritts()

View File

@ -0,0 +1,341 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
"""
from lhotse import load_manifest_lazy
def main():
paths = [
"./data/fbank/libritts_cuts_train-clean-100.jsonl.gz",
"./data/fbank/libritts_cuts_train-clean-360.jsonl.gz",
"./data/fbank/libritts_cuts_train-other-500.jsonl.gz",
"./data/fbank/libritts_cuts_dev-clean.jsonl.gz",
"./data/fbank/libritts_cuts_dev-other.jsonl.gz",
"./data/fbank/libritts_cuts_test-clean.jsonl.gz",
"./data/fbank/libritts_cuts_test-other.jsonl.gz",
]
for path in paths:
cuts = load_manifest_lazy(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
./data/fbank/libritts_cuts_train-clean-100.jsonl.gz statistics:
________________________________________
_ Cuts count: _ 33236 _
________________________________________
_ Total duration (hh:mm:ss) _ 53:47:18 _
________________________________________
_ mean _ 5.8 _
________________________________________
_ std _ 4.6 _
________________________________________
_ min _ 0.2 _
________________________________________
_ 25% _ 2.4 _
________________________________________
_ 50% _ 4.5 _
________________________________________
_ 75% _ 7.9 _
________________________________________
_ 99% _ 21.4 _
________________________________________
_ 99.5% _ 23.7 _
________________________________________
_ 99.9% _ 27.8 _
________________________________________
_ max _ 33.2 _
________________________________________
_ Recordings available: _ 33236 _
________________________________________
_ Features available: _ 33236 _
________________________________________
_ Supervisions available: _ 33236 _
________________________________________
SUPERVISION custom fields:
Speech duration statistics:
__________________________________________________________________
_ Total speech duration _ 53:47:18 _ 100.00% of recording _
__________________________________________________________________
_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _
__________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
__________________________________________________________________
./data/fbank/libritts_cuts_train-clean-360.jsonl.gz statistics:
_________________________________________
_ Cuts count: _ 116500 _
_________________________________________
_ Total duration (hh:mm:ss) _ 191:17:42 _
_________________________________________
_ mean _ 5.9 _
_________________________________________
_ std _ 4.6 _
_________________________________________
_ min _ 0.1 _
_________________________________________
_ 25% _ 2.4 _
_________________________________________
_ 50% _ 4.6 _
_________________________________________
_ 75% _ 8.1 _
_________________________________________
_ 99% _ 21.3 _
_________________________________________
_ 99.5% _ 23.4 _
_________________________________________
_ 99.9% _ 27.4 _
_________________________________________
_ max _ 40.4 _
_________________________________________
_ Recordings available: _ 116500 _
_________________________________________
_ Features available: _ 116500 _
_________________________________________
_ Supervisions available: _ 116500 _
_________________________________________
SUPERVISION custom fields:
Speech duration statistics:
___________________________________________________________________
_ Total speech duration _ 191:17:42 _ 100.00% of recording _
___________________________________________________________________
_ Total speaking time duration _ 191:17:42 _ 100.00% of recording _
___________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
___________________________________________________________________
./data/fbank/libritts_cuts_train-other-500.jsonl.gz statistics:
_________________________________________
_ Cuts count: _ 205043 _
_________________________________________
_ Total duration (hh:mm:ss) _ 310:04:36 _
_________________________________________
_ mean _ 5.4 _
_________________________________________
_ std _ 4.4 _
_________________________________________
_ min _ 0.1 _
_________________________________________
_ 25% _ 2.3 _
_________________________________________
_ 50% _ 4.2 _
_________________________________________
_ 75% _ 7.3 _
_________________________________________
_ 99% _ 20.6 _
_________________________________________
_ 99.5% _ 22.8 _
_________________________________________
_ 99.9% _ 27.4 _
_________________________________________
_ max _ 43.9 _
_________________________________________
_ Recordings available: _ 205043 _
_________________________________________
_ Features available: _ 205043 _
_________________________________________
_ Supervisions available: _ 205043 _
_________________________________________
SUPERVISION custom fields:
Speech duration statistics:
___________________________________________________________________
_ Total speech duration _ 310:04:36 _ 100.00% of recording _
___________________________________________________________________
_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _
___________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
___________________________________________________________________
./data/fbank/libritts_cuts_dev-clean.jsonl.gz statistics:
________________________________________
_ Cuts count: _ 5736 _
________________________________________
_ Total duration (hh:mm:ss) _ 08:58:13 _
________________________________________
_ mean _ 5.6 _
________________________________________
_ std _ 4.3 _
________________________________________
_ min _ 0.3 _
________________________________________
_ 25% _ 2.4 _
________________________________________
_ 50% _ 4.4 _
________________________________________
_ 75% _ 7.8 _
________________________________________
_ 99% _ 19.9 _
________________________________________
_ 99.5% _ 21.9 _
________________________________________
_ 99.9% _ 26.3 _
________________________________________
_ max _ 30.1 _
________________________________________
_ Recordings available: _ 5736 _
________________________________________
_ Features available: _ 5736 _
________________________________________
_ Supervisions available: _ 5736 _
________________________________________
SUPERVISION custom fields:
Speech duration statistics:
__________________________________________________________________
_ Total speech duration _ 08:58:13 _ 100.00% of recording _
__________________________________________________________________
_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _
__________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
__________________________________________________________________
./data/fbank/libritts_cuts_dev-other.jsonl.gz statistics:
________________________________________
_ Cuts count: _ 4613 _
________________________________________
_ Total duration (hh:mm:ss) _ 06:25:52 _
________________________________________
_ mean _ 5.0 _
________________________________________
_ std _ 4.1 _
________________________________________
_ min _ 0.3 _
________________________________________
_ 25% _ 2.2 _
________________________________________
_ 50% _ 3.8 _
________________________________________
_ 75% _ 6.5 _
________________________________________
_ 99% _ 19.7 _
________________________________________
_ 99.5% _ 24.5 _
________________________________________
_ 99.9% _ 31.0 _
________________________________________
_ max _ 32.6 _
________________________________________
_ Recordings available: _ 4613 _
________________________________________
_ Features available: _ 4613 _
________________________________________
_ Supervisions available: _ 4613 _
________________________________________
SUPERVISION custom fields:
Speech duration statistics:
__________________________________________________________________
_ Total speech duration _ 06:25:52 _ 100.00% of recording _
__________________________________________________________________
_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _
__________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
__________________________________________________________________
./data/fbank/libritts_cuts_test-clean.jsonl.gz statistics:
________________________________________
_ Cuts count: _ 4837 _
________________________________________
_ Total duration (hh:mm:ss) _ 08:34:09 _
________________________________________
_ mean _ 6.4 _
________________________________________
_ std _ 5.1 _
________________________________________
_ min _ 0.3 _
________________________________________
_ 25% _ 2.4 _
________________________________________
_ 50% _ 4.8 _
________________________________________
_ 75% _ 8.9 _
________________________________________
_ 99% _ 22.6 _
________________________________________
_ 99.5% _ 24.4 _
________________________________________
_ 99.9% _ 29.6 _
________________________________________
_ max _ 36.7 _
________________________________________
_ Recordings available: _ 4837 _
________________________________________
_ Features available: _ 4837 _
________________________________________
_ Supervisions available: _ 4837 _
________________________________________
SUPERVISION custom fields:
Speech duration statistics:
__________________________________________________________________
_ Total speech duration _ 08:34:09 _ 100.00% of recording _
__________________________________________________________________
_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _
__________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
__________________________________________________________________
./data/fbank/libritts_cuts_test-other.jsonl.gz statistics:
________________________________________
_ Cuts count: _ 5120 _
________________________________________
_ Total duration (hh:mm:ss) _ 06:41:31 _
________________________________________
_ mean _ 4.7 _
________________________________________
_ std _ 3.8 _
________________________________________
_ min _ 0.3 _
________________________________________
_ 25% _ 1.8 _
________________________________________
_ 50% _ 3.6 _
________________________________________
_ 75% _ 6.5 _
________________________________________
_ 99% _ 17.8 _
________________________________________
_ 99.5% _ 20.4 _
________________________________________
_ 99.9% _ 23.8 _
________________________________________
_ max _ 27.3 _
________________________________________
_ Recordings available: _ 5120 _
________________________________________
_ Features available: _ 5120 _
________________________________________
_ Supervisions available: _ 5120 _
________________________________________
SUPERVISION custom fields:
Speech duration statistics:
__________________________________________________________________
_ Total speech duration _ 06:41:31 _ 100.00% of recording _
__________________________________________________________________
_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _
__________________________________________________________________
_ Total silence duration _ 00:00:01 _ 0.00% of recording _
__________________________________________________________________
"""

View File

@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao,)
# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
from lhotse.dataset.speech_recognition import validate_for_asr
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest(manifest)
assert isinstance(cut_set, CutSet)
validate_for_asr(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

108
egs/libritts/ASR/prepare.sh Executable file
View File

@ -0,0 +1,108 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=0
stop_stage=100
sampling_rate=24000
perturb_speed=true
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/LibriTTS,
# you can create a symlink
#
# ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
#
if [ ! -d $dl_dir/LibriTTS ]; then
lhotse download libritts $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/musan
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LibriTTS manifest"
# We assume that you have downloaded the LibriTTS corpus
# to $dl_dir/LibriTTS
mkdir -p data/manifests
if [ ! -e data/manifests/.libritts.done ]; then
lhotse prepare libritts $dl_dir/LibriTTS data/manifests
touch data/manifests/.libritts.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
if [ ! -f data/manifests/.musan_manifests.done ]; then
log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute Fbank for LibriTTS"
mkdir -p data/fbank
if [ ! -e data/fbank/.libritts.done ]; then
./local/compute_fbank_libritts.py \
--sampling-rate $sampling_rate \
--perturb-speed $perturb_speed
touch data/fbank/.libritts.done
fi
# Here we shuffle and combine the train-clean-100, train-clean-360 and
# train-other-500 together to form the training set.
if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
fi
if [ ! -e data/fbank/.libritts-validated.done ]; then
log "Validating data/fbank for LibriTTS"
./local/validate_manifest.py \
data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
touch data/fbank/.libritts-validated.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
mkdir -p data/fbank
./local/compute_fbank_musan.py
touch data/fbank/.msuan.done
fi
fi

1
egs/libritts/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

1
egs/libritts/ASR/zipformer/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
swoosh.pdf

View File

@ -0,0 +1,459 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriTTSAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. libritts test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="""When enabled, use 960h LibriTTS.
Otherwise, use the 100h subset.""",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/attention_decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,991 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Quandong Wang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) ctc-greedy-search
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--decoding-method ctc-greedy-search
(2) ctc-decoding
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--decoding-method ctc-decoding
(3) 1best
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method 1best
(4) nbest
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method nbest
(5) nbest-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method nbest-rescoring
(6) whole-lattice-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--max-duration 600 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method whole-lattice-rescoring
(7) attention-decoder-rescoring-no-ngram
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--use-attention-decoder 1 \
--max-duration 100 \
--decoding-method attention-decoder-rescoring-no-ngram
(8) attention-decoder-rescoring-with-ngram
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--use-attention-decoder 1 \
--max-duration 100 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method attention-decoder-rescoring-with-ngram
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriTTSAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
ctc_greedy_search,
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder_no_ngram,
rescore_with_attention_decoder_with_ngram,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--decoding-method",
type=str,
default="ctc-decoding",
help="""Decoding method.
Supported values are:
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (3) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (4) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (5) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (6) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
- (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
lattice, rescore them with the attention decoder.
- (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
rescored lattice, rescore them with the attention decoder.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=1.0,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--hlg-scale",
type=float,
default=0.6,
help="""The scale to be applied to `hlg.scores`.
""",
)
parser.add_argument(
"--lm-dir",
type=str,
default="data/lm",
help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser)
return parser
def get_decoding_params() -> AttributeDict:
"""Parameters for decoding."""
params = AttributeDict(
{
"frame_shift_ms": 10,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
- params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
- params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
- params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.decoding_method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.decoding_method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
G:
An LM. It is not None when params.decoding_method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
if params.decoding_method == "ctc-greedy-search":
hyps = ctc_greedy_search(ctc_output, encoder_out_lens)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(hyps)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-greedy-search"
return {key: hyps}
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
torch.div(
supervisions["start_frame"],
params.subsampling_factor,
rounding_mode="floor",
),
torch.div(
supervisions["num_frames"],
params.subsampling_factor,
rounding_mode="floor",
),
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=ctc_output,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.decoding_method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a lit-of-list of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps} # note: returns words
if params.decoding_method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram(
lattice=lattice,
num_paths=params.num_paths,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
nbest_scale=params.nbest_scale,
)
ans = dict()
for a_scale_str, best_path in best_path_dict.items():
# token_ids is a lit-of-list of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
if params.decoding_method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
nbest_scale=params.nbest_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
return {key: hyps}
if params.decoding_method in ["1best", "nbest"]:
if params.decoding_method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no-rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps} # note: returns BPE tokens
assert params.decoding_method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.decoding_method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
elif params.decoding_method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.decoding_method == "attention-decoder-rescoring-with-ngram":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_attention_decoder_with_ngram(
lattice=rescored_lattice,
num_paths=params.num_paths,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.decoding_method}"
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
ans = None
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.decoding_method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.decoding_method is ctc-decoding.
word_table:
It is the word symbol table.
G:
An LM. It is not None when params.decoding_method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.decoding_method in (
"attention-decoder-rescoring-with-ngram",
"whole-lattice-rescoring",
):
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print(f"{key}\t{val}", file=fd)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriTTSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
args.lm_dir = Path(args.lm_dir)
params = get_params()
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
"1best",
"nbest",
"nbest-rescoring",
"whole-lattice-rescoring",
"nbest-oracle",
"attention-decoder-rescoring-no-ngram",
"attention-decoder-rescoring-with-ngram",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
if params.use_averaged_model:
params.suffix += "_use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = 0
params.eos_id = 1
params.sos_id = 1
if params.decoding_method in [
"ctc-greedy-search",
"ctc-decoding",
"attention-decoder-rescoring-no-ngram",
]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
)
assert HLG.requires_grad is False
HLG.scores *= params.hlg_scale
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.decoding_method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.decoding_method in [
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriTTSAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
)
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/generate_averaged_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/label_smoothing.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/my_profile.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_check.py

View File

@ -0,0 +1,324 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--causal False
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
2. Run this file
./zipformer/onnx_decode.py \
--exp-dir $repo/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import LibriTTSAsrDataModule
from k2 import SymbolTable
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, token_table: SymbolTable, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently it only greedy_search is supported.
Args:
model:
The neural model.
token_table:
The token table.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
hyps = [token_ids_to_words(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
token_table: SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
token_table:
The token table.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriTTSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
token_table = SymbolTable.from_file(args.tokens)
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriTTSAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained_ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1,904 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import LibriTTSAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet, set_caching_enabled
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--label",
type=str,
default="",
help="""Extra label of the decoding run.""",
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
state_list[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Returns:
state_list: A list of list. Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
tail_length = chunk_size * 2 + 7 + 2 * 3
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
# be out of range [-1, 1], hence the tolerance 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recogs_filename = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
"""
Save WER and per-utterance word alignments.
"""
test_set_wers = dict()
for key, results in results_dict.items():
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
fd, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
wer_filename = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print(f"{key}\t{val}", file=fd)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriTTSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"_beam-{params.beam}"
params.suffix += f"_max-contexts-{params.max_contexts}"
params.suffix += f"_max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
if params.label:
params.suffix += f"-{params.label}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriTTSAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io
import json
import struct
from typing import IO, Any, List, Optional
# format is `ECDC` magic code, followed by the header size as uint32.
# Then an uint8 indicates the protocol version (0.)
# The header is then provided as json and should contain all required
# informations for decoding. A raw stream of bytes is then provided
# and should be interpretable using the json header.
_encodec_header_struct = struct.Struct("!4sBI")
_ENCODEC_MAGIC = b"ECDC"
def write_ecdc_header(fo: IO[bytes], metadata: Any):
meta_dumped = json.dumps(metadata).encode("utf-8")
version = 0
header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped))
fo.write(header)
fo.write(meta_dumped)
fo.flush()
def _read_exactly(fo: IO[bytes], size: int) -> bytes:
buf = b""
while len(buf) < size:
new_buf = fo.read(size)
if not new_buf:
raise EOFError(
"Impossible to read enough data from the stream, "
f"{size} bytes remaining."
)
buf += new_buf
size -= len(new_buf)
return buf
def read_ecdc_header(fo: IO[bytes]):
header_bytes = _read_exactly(fo, _encodec_header_struct.size)
magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
if magic != _ENCODEC_MAGIC:
raise ValueError("File is not in ECDC format.")
if version != 0:
raise ValueError("Version not supported.")
meta_bytes = _read_exactly(fo, meta_size)
return json.loads(meta_bytes.decode("utf-8"))
class BitPacker:
"""Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
Note that for some bandwidth (1.5, 3), the codebook representation
will not cover an integer number of bytes.
Args:
bits (int): number of bits per value that will be pushed.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: IO[bytes]):
self._current_value = 0
self._current_bits = 0
self.bits = bits
self.fo = fo
def push(self, value: int):
"""Push a new value to the stream. This will immediately
write as many uint8 as possible to the underlying file-object."""
self._current_value += value << self._current_bits
self._current_bits += self.bits
while self._current_bits >= 8:
lower_8bits = self._current_value & 0xFF
self._current_bits -= 8
self._current_value >>= 8
self.fo.write(bytes([lower_8bits]))
def flush(self):
"""Flushes the remaining partial uint8, call this at the end
of the stream to encode."""
if self._current_bits:
self.fo.write(bytes([self._current_value]))
self._current_value = 0
self._current_bits = 0
self.fo.flush()
class BitUnpacker:
"""BitUnpacker does the opposite of `BitPacker`.
Args:
bits (int): number of bits of the values to decode.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: IO[bytes]):
self.bits = bits
self.fo = fo
self._mask = (1 << bits) - 1
self._current_value = 0
self._current_bits = 0
def pull(self) -> Optional[int]:
"""
Pull a single value from the stream, potentially reading some
extra bytes from the underlying file-object.
Returns `None` when reaching the end of the stream.
"""
while self._current_bits < self.bits:
buf = self.fo.read(1)
if not buf:
return None
character = buf[0]
self._current_value += character << self._current_bits
self._current_bits += 8
out = self._current_value & self._mask
self._current_value >>= self.bits
self._current_bits -= self.bits
return out
def test():
import torch
torch.manual_seed(1234)
for rep in range(4):
length: int = torch.randint(10, 2_000, (1,)).item()
bits: int = torch.randint(1, 16, (1,)).item()
tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
rebuilt: List[int] = []
buf = io.BytesIO()
packer = BitPacker(bits, buf)
for token in tokens:
packer.push(token)
packer.flush()
buf.seek(0)
unpacker = BitUnpacker(bits, buf)
while True:
value = unpacker.pull()
if value is None:
break
rebuilt.append(value)
assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
# The flushing mechanism might lead to "ghost" values at the end of the stream.
assert len(rebuilt) <= len(tokens) + 8 // bits, (
len(rebuilt),
len(tokens),
bits,
)
for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
assert a == b, (idx, a, b)
if __name__ == "__main__":
test()

View File

@ -0,0 +1,271 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriTTSCodecDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="Codec data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=False,
return_spk_ids=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")

View File

@ -0,0 +1,117 @@
from typing import List, Tuple
import torch
import torch.nn as nn
from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
from torch.nn import AvgPool1d
class MultiPeriodDiscriminator(nn.Module):
def __init__(self):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiScaleDiscriminator(nn.Module):
def __init__(self):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorS(),
DiscriminatorS(),
DiscriminatorS(),
]
)
self.meanpools = nn.ModuleList(
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i - 1](y)
y_hat = self.meanpools[i - 1](y_hat)
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiScaleSTFTDiscriminator(nn.Module):
"""Multi-Scale STFT (MS-STFT) discriminator.
Args:
filters (int): Number of filters in convolutions
in_channels (int): Number of input channels. Default: 1
out_channels (int): Number of output channels. Default: 1
n_ffts (Sequence[int]): Size of FFT for each scale
hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
win_lengths (Sequence[int]): Window size for each scale
**kwargs: additional args for STFTDiscriminator
"""
def __init__(
self,
filters: int,
in_channels: int = 1,
out_channels: int = 1,
n_ffts: List[int] = [1024, 2048, 512, 256, 128],
hop_lengths: List[int] = [256, 512, 128, 64, 32],
win_lengths: List[int] = [1024, 2048, 512, 256, 128],
**kwargs
):
super().__init__()
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
self.discriminators = nn.ModuleList(
[
DiscriminatorSTFT(
filters,
in_channels=in_channels,
out_channels=out_channels,
n_fft=n_ffts[i],
win_length=win_lengths[i],
hop_length=hop_lengths[i],
**kwargs
)
for i in range(len(n_ffts))
]
)
self.num_discriminators = len(self.discriminators)
def forward(self, x: torch.Tensor):
logits = []
fmaps = []
for disc in self.discriminators:
logit, fmap = disc(x)
logits.append(logit)
fmaps.append(fmap)
return logits, fmaps

View File

@ -0,0 +1,261 @@
import math
import random
from typing import List
import numpy as np
import torch
from loss import loss_dis, loss_g
from torch import nn
from torch.cuda.amp import autocast
class Encodec(nn.Module):
def __init__(
self,
sample_rate: int,
target_bandwidths: List[float],
params: dict,
encoder: nn.Module,
quantizer: nn.Module,
decoder: nn.Module,
multi_scale_discriminator: nn.Module,
multi_period_discriminator: nn.Module,
multi_scale_stft_discriminator: nn.Module,
cache_generator_outputs: bool = True,
):
super(Encodec, self).__init__()
self.params = params
# setup the generator
self.sample_rate = sample_rate
self.encoder = encoder
self.quantizer = quantizer
self.decoder = decoder
self.ratios = encoder.ratios
self.hop_length = np.prod(self.ratios)
self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios))
self.target_bandwidths = target_bandwidths
# discriminators
self.multi_scale_discriminator = multi_scale_discriminator
self.multi_period_discriminator = multi_period_discriminator
self.multi_scale_stft_discriminator = multi_scale_stft_discriminator
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
def _forward_generator(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
global_step: int,
return_sample: bool = False,
):
"""Perform generator forward.
Args:
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
global_step (int): Global step.
return_sample (bool): Return the generator output.
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
e = self.encoder(speech)
bw = random.choice(self.target_bandwidths)
quantized, codes, bandwidth, commit_loss = self.quantizer(
e, self.frame_rate, bw
)
speech_hat = self.decoder(quantized)
else:
speech_hat = self._cache
# store cache
if self.training and self.cache_generator_outputs and not reuse_cache:
self._cache = speech_hat
# calculate discriminator outputs
y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous())
with torch.no_grad():
# do not store discriminator gradient in generator turn
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)
# calculate losses
with autocast(enabled=False):
loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
commit_loss,
speech,
speech_hat,
fmap,
fmap_hat,
y,
y_hat,
global_step,
y_p,
y_p_hat,
y_s,
y_s_hat,
fmap_p,
fmap_p_hat,
fmap_s,
fmap_s_hat,
args=self.params,
)
stats = dict(
generator_loss=loss.item(),
generator_reconstruction_loss=rec_loss.item(),
generator_feature_loss=feat_loss.item(),
generator_adv_loss=adv_loss.item(),
generator_commit_loss=commit_loss.item(),
d_weight=d_weight.item(),
)
if return_sample:
stats["returned_sample"] = (
speech_hat[0].data.cpu().numpy(),
speech[0].data.cpu().numpy(),
fmap_hat[0][0].data.cpu().numpy(),
fmap[0][0].data.cpu().numpy(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def _forward_discriminator(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
global_step: int,
):
"""
Args:
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
global_step (int): Global step.
Returns:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
"""
# setup
speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
e = self.encoder(speech)
bw = random.choice(self.target_bandwidths)
quantized, codes, bandwidth, commit_loss = self.quantizer(
e, self.frame_rate, bw
)
speech_hat = self.decoder(quantized)
else:
speech_hat = self._cache
# store cache
if self.training and self.cache_generator_outputs and not reuse_cache:
self._cache = speech_hat
# calculate discriminator outputs
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
y_hat, fmap_hat = self.multi_scale_stft_discriminator(
speech_hat.contiguous().detach()
)
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)
# calculate losses
with autocast(enabled=False):
loss = loss_dis(
y,
y_hat,
fmap,
fmap_hat,
y_p,
y_p_hat,
fmap_p,
fmap_p_hat,
y_s,
y_s_hat,
fmap_s,
fmap_s_hat,
global_step,
args=self.params,
)
stats = dict(
discriminator_loss=loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
global_step: int,
return_sample: bool,
forward_generator: bool,
):
if forward_generator:
return self._forward_generator(
speech=speech,
speech_lengths=speech_lengths,
global_step=global_step,
return_sample=return_sample,
)
else:
return self._forward_discriminator(
speech=speech,
speech_lengths=speech_lengths,
global_step=global_step,
)
def encode(self, x, target_bw=None, st=None):
e = self.encoder(x)
if target_bw is None:
bw = self.target_bandwidths[-1]
else:
bw = target_bw
if st is None:
st = 0
codes = self.quantizer.encode(e, self.frame_rate, bw, st)
return codes
def decode(self, codes):
quantized = self.quantizer.decode(codes)
o = self.decoder(quantized)
return o

View File

@ -0,0 +1,298 @@
import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram
def adversarial_g_loss(y_disc_gen):
"""Hinge loss"""
loss = 0.0
for i in range(len(y_disc_gen)):
stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
loss += stft_loss
return loss / len(y_disc_gen)
def feature_loss(fmap_r, fmap_gen):
loss = 0.0
for i in range(len(fmap_r)):
for j in range(len(fmap_r[i])):
stft_loss = (
(fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
).mean()
loss += stft_loss
return loss / (len(fmap_r) * len(fmap_r[0]))
def sim_loss(y_disc_r, y_disc_gen):
loss = 0.0
for i in range(len(y_disc_r)):
loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
return loss / len(y_disc_r)
# def sisnr_loss(x, s, eps=1e-8):
# """
# calculate training loss
# input:
# x: separated signal, N x S tensor, estimate value
# s: reference signal, N x S tensor, True value
# Return:
# sisnr: N tensor
# """
# if x.shape != s.shape:
# if x.shape[-1] > s.shape[-1]:
# x = x[:, :s.shape[-1]]
# else:
# s = s[:, :x.shape[-1]]
# def l2norm(mat, keepdim=False):
# return torch.norm(mat, dim=-1, keepdim=keepdim)
# if x.shape != s.shape:
# raise RuntimeError(
# "Dimention mismatch when calculate si-snr, {} vs {}".format(
# x.shape, s.shape))
# x_zm = x - torch.mean(x, dim=-1, keepdim=True)
# s_zm = s - torch.mean(s, dim=-1, keepdim=True)
# t = torch.sum(
# x_zm * s_zm, dim=-1,
# keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
# loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
# return torch.sum(loss) / x.shape[0]
def reconstruction_loss(x, G_x, args, eps=1e-7):
# NOTE (lsx): hard-coded now
L = args.lambda_wav * F.mse_loss(x, G_x) # wav L1 loss
# loss_sisnr = sisnr_loss(G_x, x) #
# L += 0.01*loss_sisnr
# 2^6=64 -> 2^10=1024
# NOTE (lsx): add 2^11
for i in range(6, 12):
# for i in range(5, 12): # Encodec setting
s = 2**i
melspec = MelSpectrogram(
sample_rate=args.sr,
n_fft=max(s, 512),
win_length=s,
hop_length=s // 4,
n_mels=64,
wkwargs={"device": args.device},
).to(args.device)
S_x = melspec(x)
S_G_x = melspec(G_x)
l1_loss = (S_x - S_G_x).abs().mean()
l2_loss = (
((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
dim=-2
)
** 0.5
).mean()
alpha = (s / 2) ** 0.5
L += l1_loss + alpha * l2_loss
return L
def criterion_d(
y_disc_r,
y_disc_gen,
fmap_r_det,
fmap_gen_det,
y_df_hat_r,
y_df_hat_g,
fmap_f_r,
fmap_f_g,
y_ds_hat_r,
y_ds_hat_g,
fmap_s_r,
fmap_s_g,
):
"""Hinge Loss"""
loss = 0.0
loss1 = 0.0
loss2 = 0.0
loss3 = 0.0
for i in range(len(y_disc_r)):
loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
for i in range(len(y_df_hat_r)):
loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
for i in range(len(y_ds_hat_r)):
loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()
loss = (
loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
) / 3.0
return loss
def criterion_g(
commit_loss,
x,
G_x,
fmap_r,
fmap_gen,
y_disc_r,
y_disc_gen,
y_df_hat_r,
y_df_hat_g,
fmap_f_r,
fmap_f_g,
y_ds_hat_r,
y_ds_hat_g,
fmap_s_r,
fmap_s_g,
args,
):
adv_g_loss = adversarial_g_loss(y_disc_gen)
feat_loss = (
feature_loss(fmap_r, fmap_gen)
+ sim_loss(y_disc_r, y_disc_gen)
+ feature_loss(fmap_f_r, fmap_f_g)
+ sim_loss(y_df_hat_r, y_df_hat_g)
+ feature_loss(fmap_s_r, fmap_s_g)
+ sim_loss(y_ds_hat_r, y_ds_hat_g)
) / 3.0
rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
total_loss = (
args.lambda_com * commit_loss
+ args.lambda_adv * adv_g_loss
+ args.lambda_feat * feat_loss
+ args.lambda_rec * rec_loss
)
return total_loss, adv_g_loss, feat_loss, rec_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
# 0,3,6,9,13....这些时间步不更新dis
if global_step % 3 == 0:
weight = value
return weight
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
print("last_layer cannot be none")
assert 1 == 2
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
d_weight = d_weight * args.lambda_adv
return d_weight
def loss_g(
codebook_loss,
speech,
speech_hat,
fmap,
fmap_hat,
y,
y_hat,
global_step,
y_df,
y_df_hat,
y_ds,
y_ds_hat,
fmap_f,
fmap_f_hat,
fmap_s,
fmap_s_hat,
args=None,
):
"""
args:
codebook_loss: commit loss.
speech: ground-truth wav.
speech_hat: reconstructed wav.
fmap: real stft-D feature map.
fmap_hat: fake stft-D feature map.
y: real stft-D logits.
y_hat: fake stft-D logits.
global_step: global training step.
y_df: real MPD logits.
y_df_hat: fake MPD logits.
y_ds: real MSD logits.
y_ds_hat: fake MSD logits.
fmap_f: real MPD feature map.
fmap_f_hat: fake MPD feature map.
fmap_s: real MSD feature map.
fmap_s_hat: fake MSD feature map.
"""
rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
adv_g_loss = adversarial_g_loss(y_hat)
adv_mpd_loss = adversarial_g_loss(y_df_hat)
adv_msd_loss = adversarial_g_loss(y_ds_hat)
adv_loss = (
adv_g_loss + adv_mpd_loss + adv_msd_loss
) / 3.0 # NOTE(lsx): need to divide by 3?
feat_loss = feature_loss(
fmap, fmap_hat
) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
feat_loss_mpd = feature_loss(
fmap_f, fmap_f_hat
) # + sim_loss(y_df_hat_r, y_df_hat_g)
feat_loss_msd = feature_loss(
fmap_s, fmap_s_hat
) # + sim_loss(y_ds_hat_r, y_ds_hat_g)
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
d_weight = torch.tensor(1.0)
disc_factor = adopt_weight(
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
)
if disc_factor == 0.0:
fm_loss_wt = 0
else:
fm_loss_wt = args.lambda_feat
loss = (
rec_loss
+ d_weight * disc_factor * adv_loss
+ fm_loss_wt * feat_loss_tot
+ args.lambda_com * codebook_loss
)
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
def loss_dis(
y,
y_hat,
fmap,
fmap_hat,
y_df,
y_df_hat,
fmap_f,
fmap_f_hat,
y_ds,
y_ds_hat,
fmap_s,
fmap_s_hat,
global_step,
args,
):
disc_factor = adopt_weight(
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
)
d_loss = disc_factor * criterion_d(
y,
y_hat,
fmap,
fmap_hat,
y_df,
y_df_hat,
fmap_f,
fmap_f_hat,
y_ds,
y_ds_hat,
fmap_s,
fmap_s_hat,
)
return d_loss

View File

@ -0,0 +1,229 @@
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange
from utils import get_2d_padding, get_padding
from ..modules import NormConv1d, NormConv2d
class DiscriminatorP(nn.Module):
def __init__(
self,
period,
kernel_size=5,
stride=3,
activation: str = "LeakyReLU",
activation_params: dict = {"negative_slope": 0.2},
):
super(DiscriminatorP, self).__init__()
self.period = period
self.activation = getattr(torch.nn, activation)(**activation_params)
self.convs = nn.ModuleList(
[
NormConv2d(
1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)
),
NormConv2d(
32,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
),
NormConv2d(
32,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
),
NormConv2d(
32,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
),
NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)),
]
)
self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = self.activation(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(nn.Module):
def __init__(
self,
activation: str = "LeakyReLU",
activation_params: dict = {"negative_slope": 0.2},
):
super(DiscriminatorS, self).__init__()
self.activation = getattr(torch.nn, activation)(**activation_params)
self.convs = nn.ModuleList(
[
NormConv1d(1, 32, 15, 1, padding=7),
NormConv1d(32, 32, 41, 2, groups=4, padding=20),
NormConv1d(32, 32, 41, 2, groups=16, padding=20),
NormConv1d(32, 32, 41, 4, groups=16, padding=20),
NormConv1d(32, 32, 41, 4, groups=16, padding=20),
NormConv1d(32, 32, 41, 1, groups=16, padding=20),
NormConv1d(32, 32, 5, 1, padding=2),
]
)
self.conv_post = NormConv1d(32, 1, 3, 1, padding=1)
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = self.activation(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorSTFT(nn.Module):
"""STFT sub-discriminator.
Args:
filters (int): Number of filters in convolutions
in_channels (int): Number of input channels. Default: 1
out_channels (int): Number of output channels. Default: 1
n_fft (int): Size of FFT for each scale. Default: 1024
hop_length (int): Length of hop between STFT windows for each scale. Default: 256
kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
win_length (int): Window size for each scale. Default: 1024
normalized (bool): Whether to normalize by magnitude after stft. Default: True
norm (str): Normalization method. Default: `'weight_norm'`
activation (str): Activation function. Default: `'LeakyReLU'`
activation_params (dict): Parameters to provide to the activation function.
growth (int): Growth factor for the filters. Default: 1
"""
def __init__(
self,
n_filters: int,
in_channels: int = 1,
out_channels: int = 1,
n_fft: int = 1024,
hop_length: int = 256,
win_length: int = 1024,
max_filters: int = 1024,
filters_scale: int = 1,
kernel_size: Tuple[int, int] = (3, 9),
dilations: List[int] = [1, 2, 4],
stride: Tuple[int, int] = (1, 2),
normalized: bool = True,
norm: str = "weight_norm",
activation: str = "LeakyReLU",
activation_params: dict = {"negative_slope": 0.2},
):
super().__init__()
assert len(kernel_size) == 2
assert len(stride) == 2
self.filters = n_filters
self.in_channels = in_channels
self.out_channels = out_channels
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.normalized = normalized
self.activation = getattr(torch.nn, activation)(**activation_params)
self.spec_transform = torchaudio.transforms.Spectrogram(
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window_fn=torch.hann_window,
normalized=self.normalized,
center=False,
pad_mode=None,
power=None,
)
spec_channels = 2 * self.in_channels
self.convs = nn.ModuleList()
self.convs.append(
NormConv2d(
spec_channels,
self.filters,
kernel_size=kernel_size,
padding=get_2d_padding(kernel_size),
)
)
in_chs = min(filters_scale * self.filters, max_filters)
for i, dilation in enumerate(dilations):
out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
self.convs.append(
NormConv2d(
in_chs,
out_chs,
kernel_size=kernel_size,
stride=stride,
dilation=(dilation, 1),
padding=get_2d_padding(kernel_size, (dilation, 1)),
norm=norm,
)
)
in_chs = out_chs
out_chs = min(
(filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
)
self.convs.append(
NormConv2d(
in_chs,
out_chs,
kernel_size=(kernel_size[0], kernel_size[0]),
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
norm=norm,
)
)
self.conv_post = NormConv2d(
out_chs,
self.out_channels,
kernel_size=(kernel_size[0], kernel_size[0]),
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
norm=norm,
)
def forward(self, x: torch.Tensor):
fmap = []
# print('x ', x.shape)
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
# print('z ', z.shape)
z = torch.cat([z.real, z.imag], dim=1)
# print('cat_z ', z.shape)
z = rearrange(z, "b c w t -> b c t w")
for i, layer in enumerate(self.convs):
z = layer(z)
z = self.activation(z)
# print('z i', i, z.shape)
fmap.append(z)
z = self.conv_post(z)
# print('logit ', z.shape)
return z, fmap

View File

@ -0,0 +1,12 @@
from typing import Tuple
def get_padding(kernel_size, dilation=1) -> int:
return int((kernel_size * dilation - dilation) / 2)
def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
return (
((kernel_size[0] - 1) * dilation[0]) // 2,
((kernel_size[1] - 1) * dilation[1]) // 2,
)

View File

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch modules."""
# flake8: noqa
from .conv import (
NormConv1d,
NormConv2d,
NormConvTranspose1d,
NormConvTranspose2d,
SConv1d,
SConvTranspose1d,
pad1d,
unpad1d,
)
from .lstm import SLSTM
from .seanet import SEANetDecoder, SEANetEncoder
from .transformer import StreamingTransformerEncoder

View File

@ -0,0 +1,334 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""
import logging
import math
from typing import Any, Dict, Tuple
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
from .norm import ConvLayerNorm
CONV_NORMALIZATIONS = frozenset(
[
"none",
"weight_norm",
"spectral_norm",
"time_layer_norm",
"layer_norm",
"time_group_norm",
]
)
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == "weight_norm":
return weight_norm(module)
elif norm == "spectral_norm":
return spectral_norm(module)
else:
# We already check was in CONV_NORMALIZATION, so any other choice
# doesn't need reparametrization.
return module
def get_norm_module(
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
"""Return the proper normalization module. If causal is True, this will ensure the returned
module is causal, or return an error if the normalization doesn't support causal evaluation.
"""
assert norm in CONV_NORMALIZATIONS
if norm == "layer_norm":
assert isinstance(module, nn.modules.conv._ConvNd)
return ConvLayerNorm(module.out_channels, **norm_kwargs)
elif norm == "time_group_norm":
if causal:
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()
def get_extra_padding_for_conv1d(
x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
"""See `pad_for_conv1d`."""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad_for_conv1d(x: Tensor, kernel_size: int, stride: int, padding_total: int = 0):
"""Pad for a convolution to make sure that the last window is full.
Extra padding is added at the end. This is required to ensure that we can rebuild
an output of the same length, as otherwise, even with padding, some time steps
might get removed.
For instance, with total padding = 4, kernel size = 4, stride = 2:
0 0 1 2 3 4 5 0 0 # (0s are padding)
1 2 3 # (output frames of a convolution, last 0 is never used)
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
1 2 3 4 # once you removed padding, we are missing one time step !
"""
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
return F.pad(x, (0, extra_padding))
def pad1d(
x: Tensor,
paddings: Tuple[int, int],
mode: str = "zero",
value: float = 0.0,
):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
def unpad1d(x: Tensor, paddings: Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left:end]
class NormConv1d(nn.Module):
"""Wrapper around Conv1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
causal: bool = False,
norm: str = "none",
norm_kwargs: Dict[str, Any] = {},
**kwargs,
):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConv2d(nn.Module):
"""Wrapper around Conv2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
norm: str = "none",
norm_kwargs: Dict[str, Any] = {},
**kwargs,
):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConvTranspose1d(nn.Module):
"""Wrapper around ConvTranspose1d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
causal: bool = False,
norm: str = "none",
norm_kwargs: Dict[str, Any] = {},
**kwargs,
):
super().__init__()
self.convtr = apply_parametrization_norm(
nn.ConvTranspose1d(*args, **kwargs), norm
)
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class NormConvTranspose2d(nn.Module):
"""Wrapper around ConvTranspose2d and normalization applied to this conv
to provide a uniform interface across normalization approaches.
"""
def __init__(
self,
*args,
norm: str = "none",
norm_kwargs: Dict[str, Any] = {},
**kwargs,
):
super().__init__()
self.convtr = apply_parametrization_norm(
nn.ConvTranspose2d(*args, **kwargs), norm
)
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class SConv1d(nn.Module):
"""Conv1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
causal: bool = False,
norm: str = "none",
norm_kwargs: Dict[str, Any] = {},
pad_mode: str = "reflect",
):
super().__init__()
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
logging.warning(
"SConv1d has been initialized with stride > 1 and dilation > 1"
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
)
self.conv = NormConv1d(
in_channels,
out_channels,
kernel_size,
stride,
dilation=dilation,
groups=groups,
bias=bias,
causal=causal,
norm=norm,
norm_kwargs=norm_kwargs,
)
self.causal = causal
self.pad_mode = pad_mode
def forward(self, x):
B, C, T = x.shape
kernel_size = self.conv.conv.kernel_size[0]
stride = self.conv.conv.stride[0]
dilation = self.conv.conv.dilation[0]
padding_total = (kernel_size - 1) * dilation - (stride - 1)
extra_padding = get_extra_padding_for_conv1d(
x, kernel_size, stride, padding_total
)
if self.causal:
# Left padding for causal
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(
x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
)
return self.conv(x)
class SConvTranspose1d(nn.Module):
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
and normalization.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
causal: bool = False,
norm: str = "none",
trim_right_ratio: float = 1.0,
norm_kwargs: Dict[str, Any] = {},
):
super().__init__()
self.convtr = NormConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
causal=causal,
norm=norm,
norm_kwargs=norm_kwargs,
)
self.causal = causal
self.trim_right_ratio = trim_right_ratio
assert (
self.causal or self.trim_right_ratio == 1.0
), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
def forward(self, x):
kernel_size = self.convtr.convtr.kernel_size[0]
stride = self.convtr.convtr.stride[0]
padding_total = kernel_size - stride
y = self.convtr(x)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# in the encoder.
if self.causal:
# Trim the padding on the right according to the specified ratio
# if trim_right_ratio = 1.0, trim everything from right
padding_right = math.ceil(padding_total * self.trim_right_ratio)
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
return y

View File

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""LSTM layers module."""
from torch import nn
class SLSTM(nn.Module):
"""
LSTM without worrying about the hidden state, nor the layout of the data.
Expects input as convolutional layout.
"""
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
super().__init__()
self.skip = skip
self.lstm = nn.LSTM(dimension, dimension, num_layers)
def forward(self, x):
x = x.permute(2, 0, 1)
y, _ = self.lstm(x)
if self.skip:
y = y + x
y = y.permute(1, 2, 0)
return y

View File

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Normalization modules."""
from typing import List, Union
import einops
import torch
from torch import nn
class ConvLayerNorm(nn.LayerNorm):
"""
Convolution-friendly LayerNorm that moves channels to last dimensions
before running the normalization and moves them back to original position right after.
"""
def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
super().__init__(normalized_shape, **kwargs)
def forward(self, x):
x = einops.rearrange(x, "b ... t -> b t ...")
x = super().forward(x)
x = einops.rearrange(x, "b t ... -> b ... t")
return

View File

@ -0,0 +1,368 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Encodec SEANet-based encoder and decoder implementation."""
from typing import Any, Dict, List, Optional
import numpy as np
import torch.nn as nn
from modules import SLSTM, SConv1d, SConvTranspose1d
class SEANetResnetBlock(nn.Module):
"""Residual block from SEANet model.
Args:
dim (int): Dimension of the input/output
kernel_sizes (list): List of kernel sizes for the convolutions.
dilations (list): List of dilations for the convolutions.
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
"""
def __init__(
self,
dim: int,
kernel_sizes: List[int] = [3, 1],
dilations: List[int] = [1, 1],
activation: str = "ELU",
activation_params: Dict = {"alpha": 1.0},
norm: str = "weight_norm",
norm_params: Dict[str, Any] = {},
causal: bool = False,
pad_mode: str = "reflect",
compress: int = 2,
true_skip: bool = True,
):
super().__init__()
assert len(kernel_sizes) == len(
dilations
), "Number of kernel sizes should match number of dilations"
act = getattr(nn, activation)
hidden = dim // compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [
act(**activation_params),
SConv1d(
in_chs,
out_chs,
kernel_size=kernel_size,
dilation=dilation,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
self.block = nn.Sequential(*block)
self.shortcut: nn.Module
if true_skip:
self.shortcut = nn.Identity()
else:
self.shortcut = SConv1d(
dim,
dim,
kernel_size=1,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
)
def forward(self, x):
return self.shortcut(x) + self.block(x)
class SEANetEncoder(nn.Module):
"""SEANet encoder.
Args:
channels (int): Audio channels.
dimension (int): Intermediate representation dimension.
n_filters (int): Base width for the model.
n_residual_layers (int): nb of residual layers.
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
that must match the decoder order
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
kernel_size (int): Kernel size for the initial convolution.
last_kernel_size (int): Kernel size for the initial convolution.
residual_kernel_size (int): Kernel size for the residual layers.
dilation_base (int): How much to increase the dilation with each layer.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection in the residual network blocks.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
lstm (int): Number of LSTM layers at the end of the encoder.
"""
def __init__(
self,
channels: int = 1,
dimension: int = 128,
n_filters: int = 32,
n_residual_layers: int = 1,
ratios: List[int] = [8, 5, 4, 2],
activation: str = "ELU",
activation_params: dict = {"alpha": 1.0},
norm: str = "weight_norm",
norm_params: Dict[str, Any] = {},
kernel_size: int = 7,
last_kernel_size: int = 7,
residual_kernel_size: int = 3,
dilation_base: int = 2,
causal: bool = False,
pad_mode: str = "reflect",
true_skip: bool = False,
compress: int = 2,
lstm: int = 2,
):
super().__init__()
self.channels = channels
self.dimension = dimension
self.n_filters = n_filters
self.ratios = list(reversed(ratios))
del ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios) # 计算乘积
act = getattr(nn, activation)
mult = 1
model: List[nn.Module] = [
SConv1d(
channels,
mult * n_filters,
kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
)
]
# Downsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(
mult * n_filters,
kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base**j, 1],
norm=norm,
norm_params=norm_params,
activation=activation,
activation_params=activation_params,
causal=causal,
pad_mode=pad_mode,
compress=compress,
true_skip=true_skip,
)
]
# Add downsampling layers
model += [
act(**activation_params),
SConv1d(
mult * n_filters,
mult * n_filters * 2,
kernel_size=ratio * 2,
stride=ratio,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
mult *= 2
if lstm:
model += [SLSTM(mult * n_filters, num_layers=lstm)]
model += [
act(**activation_params),
SConv1d(
mult * n_filters,
dimension,
last_kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
class SEANetDecoder(nn.Module):
"""SEANet decoder.
Args:
channels (int): Audio channels.
dimension (int): Intermediate representation dimension.
n_filters (int): Base width for the model.
n_residual_layers (int): nb of residual layers.
ratios (Sequence[int]): kernel size and stride ratios
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
final_activation (str): Final activation function after all convolutions.
final_activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
kernel_size (int): Kernel size for the initial convolution.
last_kernel_size (int): Kernel size for the initial convolution.
residual_kernel_size (int): Kernel size for the residual layers.
dilation_base (int): How much to increase the dilation with each layer.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
true_skip (bool): Whether to use true skip connection or a simple
(streamable) convolution as the skip connection in the residual network blocks.
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
lstm (int): Number of LSTM layers at the end of the encoder.
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
If equal to 1.0, it means that all the trimming is done at the right.
"""
def __init__(
self,
channels: int = 1,
dimension: int = 128,
n_filters: int = 32,
n_residual_layers: int = 1,
ratios: List[int] = [8, 5, 4, 2],
activation: str = "ELU",
activation_params: dict = {"alpha": 1.0},
final_activation: Optional[str] = None,
final_activation_params: Optional[dict] = None,
norm: str = "weight_norm",
norm_params: Dict[str, Any] = {},
kernel_size: int = 7,
last_kernel_size: int = 7,
residual_kernel_size: int = 3,
dilation_base: int = 2,
causal: bool = False,
pad_mode: str = "reflect",
true_skip: bool = False,
compress: int = 2,
lstm: int = 2,
trim_right_ratio: float = 1.0,
):
super().__init__()
self.dimension = dimension
self.channels = channels
self.n_filters = n_filters
self.ratios = ratios
del ratios
self.n_residual_layers = n_residual_layers
self.hop_length = np.prod(self.ratios)
act = getattr(nn, activation)
mult = int(2 ** len(self.ratios))
model: List[nn.Module] = [
SConv1d(
dimension,
mult * n_filters,
kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
)
]
if lstm:
model += [SLSTM(mult * n_filters, num_layers=lstm)]
# Upsample to raw audio scale
for i, ratio in enumerate(self.ratios):
# Add upsampling layers
model += [
act(**activation_params),
SConvTranspose1d(
mult * n_filters,
mult * n_filters // 2,
kernel_size=ratio * 2,
stride=ratio,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
trim_right_ratio=trim_right_ratio,
),
]
# Add residual layers
for j in range(n_residual_layers):
model += [
SEANetResnetBlock(
mult * n_filters // 2,
kernel_sizes=[residual_kernel_size, 1],
dilations=[dilation_base**j, 1],
activation=activation,
activation_params=activation_params,
norm=norm,
norm_params=norm_params,
causal=causal,
pad_mode=pad_mode,
compress=compress,
true_skip=true_skip,
)
]
mult //= 2
# Add final layers
model += [
act(**activation_params),
SConv1d(
n_filters,
channels,
last_kernel_size,
norm=norm,
norm_kwargs=norm_params,
causal=causal,
pad_mode=pad_mode,
),
]
# Add optional final activation to decoder (eg. tanh)
if final_activation is not None:
final_act = getattr(nn, final_activation)
final_activation_params = final_activation_params or {}
model += [final_act(**final_activation_params)]
self.model = nn.Sequential(*model)
def forward(self, z):
y = self.model(z)
return y
def test():
import torch
encoder = SEANetEncoder()
decoder = SEANetDecoder()
x = torch.randn(1, 1, 24000)
z = encoder(x)
print("z ", z.shape)
assert 1 == 2
assert list(z.shape) == [1, 128, 75], z.shape
y = decoder(z)
assert y.shape == x.shape, (x.shape, y.shape)
if __name__ == "__main__":
test()

View File

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A streamable transformer."""
import typing as tp
from typing import Any, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000):
"""Create time embedding for the given positions, target dimension `dim`."""
# We aim for BTC format
assert dim % 2 == 0
half_dim = dim // 2
adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
phase = positions / (max_period ** (adim / (half_dim - 1)))
return torch.cat(
[
torch.cos(phase),
torch.sin(phase),
],
dim=-1,
)
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
def forward(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore
if self.norm_first:
sa_input = self.norm1(x)
x = x + self._sa_block(sa_input, x_past, past_context)
x = x + self._ff_block(self.norm2(x))
else:
sa_input = x
x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
x = self.norm2(x + self._ff_block(x))
return x, sa_input
# self-attention block
def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore
_, T, _ = x.shape
_, H, _ = x_past.shape
queries = x
keys = torch.cat([x_past, x], dim=1)
values = keys
queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
delta = queries_pos - keys_pos
valid_access = (delta >= 0) & (delta <= past_context)
x = self.self_attn(
queries, keys, values, attn_mask=~valid_access, need_weights=False
)[0]
return self.dropout1(x)
class StreamingTransformerEncoder(nn.Module):
"""TransformerEncoder with streaming support.
Args:
dim (int): dimension of the data.
hidden_scale (int): intermediate dimension of FF module is this times the dimension.
num_heads (int): number of heads.
num_layers (int): number of layers.
max_period (float): maxium period of cosines in the positional embedding.
past_context (int or None): receptive field for the causal mask, infinite if None.
gelu (bool): if true uses GeLUs, otherwise use ReLUs.
norm_in (bool): normalize the input.
dropout (float): dropout probability.
**kwargs: See `nn.TransformerEncoderLayer`.
"""
def __init__(
self,
dim,
hidden_scale: float = 4.0,
num_heads: int = 8,
num_layers: int = 5,
max_period: float = 10000,
past_context: int = 1000,
gelu: bool = True,
norm_in: bool = True,
dropout: float = 0.0,
**kwargs
):
super().__init__()
assert dim % num_heads == 0
hidden_dim = int(dim * hidden_scale)
self.max_period = max_period
self.past_context = past_context
activation: Any = F.gelu if gelu else F.relu
self.norm_in: nn.Module
if norm_in:
self.norm_in = nn.LayerNorm(dim)
else:
self.norm_in = nn.Identity()
self.layers = nn.ModuleList()
for idx in range(num_layers):
self.layers.append(
StreamingTransformerEncoderLayer(
dim,
num_heads,
hidden_dim,
activation=activation,
batch_first=True,
dropout=dropout,
**kwargs
)
)
def forward(
self,
x: Tensor,
states: Optional[List[Tensor]] = None,
offset: Union[int, Tensor] = 0,
):
B, T, C = x.shape
if states is None:
states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
new_state: List[Tensor] = []
x = self.norm_in(x)
x = x + pos_emb
for layer_state, layer in zip(states, self.layers):
x, new_layer_state = layer(x, layer_state, self.past_context)
new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
new_state.append(new_layer_state[:, -self.past_context :, :])
return x, new_state, offset + T

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer

View File

@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Arithmetic coder."""
import io
import math
import random
from typing import IO, Any, List, Optional
import torch
from torch import Tensor
from ..binary import BitPacker, BitUnpacker
def build_stable_quantized_cdf(
pdf: Tensor,
total_range_bits: int,
roundoff: float = 1e-8,
min_range: int = 2,
check: bool = True,
) -> Tensor:
"""Turn the given PDF into a quantized CDF that splits
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
to the PDF.
Args:
pdf (Tensor): probability distribution, shape should be `[N]`.
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
during the coding process is `[0, 2 ** total_range_bits - 1]`.
roundoff (float): will round the pdf up to that level to remove difference coming
from e.g. evaluating the Language Model on different architectures.
min_range (int): minimum range width. Should always be at least 2 for numerical
stability. Use this to avoid pathological behavior is a value
that is expected to be rare actually happens in real life.
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
pdf = pdf.detach()
if roundoff:
pdf = (pdf / roundoff).floor() * roundoff
# interpolate with uniform distribution to achieve desired minimum probability.
total_range = 2**total_range_bits
cardinality = len(pdf)
alpha = min_range * cardinality / total_range
assert alpha <= 1, "you must reduce min_range"
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
ranges += min_range
quantized_cdf = torch.cumsum(ranges, dim=-1)
if min_range < 2:
raise ValueError("min_range must be at least 2.")
if check:
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
if (
(quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
).any() or quantized_cdf[0] < min_range:
raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
sequence `(s_t)` by doing the following:
1) Initialize the current range to` [0 ** 2 B - 1]`.
2) For each time step t, split the current range into contiguous chunks,
one for each possible outcome, with size roughly proportional to `p`.
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
would be `{[0, 2], [3, 3]}`.
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
4) When done encoding all the values, just select any value remaining in the range.
You will notice that this procedure can fail: for instance if at any point in time
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
possible outcome. Intuitively, the more likely a value is, the less the range width
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
coding scheme, likely outcomes would take less bits, and more of them can be coded
with a fixed budget.
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
when the current range decreases below a given limit (given by `total_range_bits`), without
having to redo all the computations. If we encode mostly likely values, we will seldom
need to inject new bits, but a single rare value can deplete our stock of entropy!
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
code works for any sequence `(p_t)` possibly different for each timestep.
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
the KL between the true distribution and `p_t`, the most efficient the coding will be.
Args:
fo (IO[bytes]): file-like object to which the bytes will be written to.
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
Any time the current range width fall under this limit, new bits will
be injected to rescale the initial range.
"""
def __init__(self, fo: IO[bytes], total_range_bits: int = 24):
assert total_range_bits <= 30
self.total_range_bits = total_range_bits
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
self.low: int = 0
self.high: int = 0
self.max_bit: int = -1
self._dbg: List[Any] = []
self._dbg2: List[Any] = []
@property
def delta(self) -> int:
"""Return the current range width."""
return self.high - self.low + 1
def _flush_common_prefix(self):
# If self.low and self.high start with the sames bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
assert self.high < 2 ** (self.max_bit + 1)
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= b1 << self.max_bit
self.high -= b1 << self.max_bit
assert self.high >= self.low, (self.high, self.low, self.max_bit)
assert self.low >= 0
self.max_bit -= 1
self.packer.push(b1)
else:
break
def push(self, symbol: int, quantized_cdf: Tensor):
"""Push the given symbol on the stream, flushing out bits
if possible.
Args:
symbol (int): symbol to encode with the AC.
quantized_cdf (Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
while self.delta < 2**self.total_range_bits:
self.low *= 2
self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
effective_low = int(
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
)
effective_high = int(
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
)
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
assert self.low <= self.high, (
effective_low,
effective_high,
range_low,
range_high,
)
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
outs = self._flush_common_prefix()
assert self.low <= self.high
assert self.max_bit >= -1
assert self.max_bit <= 61, self.max_bit
return outs
def flush(self):
"""Flush the remaining information to the stream."""
while self.max_bit >= 0:
b1 = (self.low >> self.max_bit) & 1
self.packer.push(b1)
self.max_bit -= 1
self.packer.flush()
class ArithmeticDecoder:
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
If the AC encoder current range is [L, H], with `L` and `H` having the some common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
`[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
def __init__(self, fo: IO[bytes], total_range_bits: int = 24):
self.total_range_bits = total_range_bits
self.low: int = 0
self.high: int = 0
self.current: int = 0
self.max_bit: int = -1
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
# Following is for debugging
self._dbg: List[Any] = []
self._dbg2: List[Any] = []
self._last: Any = None
@property
def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
# Given the current range [L, H], if both have a common prefix,
# we know we can remove it from our representation to avoid handling large numbers.
while self.max_bit >= 0:
b1 = self.low >> self.max_bit
b2 = self.high >> self.max_bit
if b1 == b2:
self.low -= b1 << self.max_bit
self.high -= b1 << self.max_bit
self.current -= b1 << self.max_bit
assert self.high >= self.low
assert self.low >= 0
self.max_bit -= 1
else:
break
def pull(self, quantized_cdf: Tensor) -> Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
This returns `None` when the stream has been exhausted.
Args:
quantized_cdf (Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. This must be **exatly**
the same cdf as the one used at encoding time.
"""
while self.delta < 2**self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
self.low *= 2
self.high = self.high * 2 + 1
self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
# Binary search is not just for coding interviews :)
if high_idx < low_idx:
raise RuntimeError("Binary search failed")
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
effective_low = int(
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
)
effective_high = int(
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
)
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
if self.current <= high:
return (mid, low, high, self.current)
else:
return bin_search(mid + 1, high_idx)
else:
return bin_search(low_idx, mid - 1)
self._last = (self.low, self.high, self.current, self.max_bit)
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
return sym
def test():
torch.manual_seed(1234)
random.seed(1234)
for _ in range(4):
pdfs = []
cardinality = random.randrange(4000)
steps = random.randrange(100, 500)
fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
symbols = []
for step in range(steps):
pdf = torch.softmax(torch.randn(cardinality), dim=0)
pdfs.append(pdf)
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
symbol = torch.multinomial(pdf, 1).item()
symbols.append(symbol)
encoder.push(symbol, q_cdf)
encoder.flush()
fo.seek(0)
decoder = ArithmeticDecoder(fo)
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
decoded_symbol = decoder.pull(q_cdf)
assert decoded_symbol == symbol, idx
assert decoder.pull(torch.zeros(1)) is None
if __name__ == "__main__":
test()

View File

@ -0,0 +1,377 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from .distrib import broadcast_tensors
def default(val: Any, d: Any) -> Any:
return val if val is not None else d
def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance.
Args:
dim (int): Dimension.
codebook_size (int): Codebook size.
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
If set to true, run the k-means algorithm on the first training batch and use
the learned centroids as initialization.
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
codebook_size: int,
kmeans_init: int = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.decay = decay
init_fn: Union[Callable[..., torch.Tensor], Any] = (
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
self.register_buffer("cluster_size", torch.zeros(codebook_size))
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
@torch.jit.ignore
def init_embed_(self, data):
if self.inited:
return
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
self.embed.data.copy_(embed)
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
expired_codes = self.cluster_size < self.threshold_ema_dead_code
if not torch.any(expired_codes):
return
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
return x
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices
return embed_ind
def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
def dequantize(self, embed_ind):
quantize = F.embedding(embed_ind, self.embed)
return quantize
def encode(self, x):
shape = x.shape
# pre-process
x = self.preprocess(x)
# quantize
embed_ind = self.quantize(x)
# post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
def decode(self, embed_ind):
quantize = self.dequantize(embed_ind)
return quantize
def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
self.init_embed_(x)
embed_ind = self.quantize(x)
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
embed_ind = self.postprocess_emb(embed_ind, shape)
quantize = self.dequantize(embed_ind)
if self.training:
# We do the expiry of code at that point as buffers are in sync
# and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
return quantize, embed_ind
class VectorQuantization(nn.Module):
"""Vector quantization implementation.
Currently supports only euclidean distance.
Args:
dim (int): Dimension
codebook_size (int): Codebook size
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
decay (float): Decay for exponential moving average over the codebooks.
epsilon (float): Epsilon value for numerical stability.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
codebook_size: int,
codebook_dim: Optional[int] = None,
decay: float = 0.99,
epsilon: float = 1e-5,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
def codebook(self):
return self._codebook.embed
def encode(self, x):
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
if self.training:
quantize = x + (quantize - x).detach()
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
if self.training:
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
def forward(self, x, n_q: Optional[int] = None):
quantized_out = 0.0
residual = x
all_losses = []
all_indices = []
n_q = n_q or len(self.layers)
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
def encode(
self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None
) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
st = st or 0
for layer in self.layers[st:n_q]: # 设置解码的起止layer
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out

View File

@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
from typing import Dict, Iterable, List
import torch
from torch import distributed as dist
def rank():
if dist.is_initialized():
return dist.get_rank()
else:
return 0
def world_size():
if dist.is_initialized():
return dist.get_world_size()
else:
return 1
def is_distributed():
return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM):
if is_distributed():
return dist.all_reduce(tensor, op)
def _is_complex_or_float(tensor):
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
def _check_number_of_params(params: List[torch.Tensor]):
# utility function to check that the number of params in all workers is the same,
# and thus avoid a deadlock with distributed all reduce.
if not is_distributed() or not params:
return
# print('params[0].device ', params[0].device)
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
all_reduce(tensor)
if tensor.item() != len(params) * world_size():
# If not all the workers have the same number, for at least one of them,
# this inequality will be verified.
raise RuntimeError(
f"Mismatch in number of params: ours is {len(params)}, "
"at least one worker has a different one."
)
def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0):
"""Broadcast the tensors from the given parameters to all workers.
This can be used to ensure that all workers have the same model to start with.
"""
if not is_distributed():
return
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
_check_number_of_params(tensors)
handles = []
for tensor in tensors:
# src = int(rank()) # added code
handle = dist.broadcast(tensor.data, src=src, async_op=True)
handles.append(handle)
for handle in handles:
handle.wait()
def sync_buffer(buffers, average=True):
"""
Sync grad for buffers. If average is False, broadcast instead of averaging.
"""
if not is_distributed():
return
handles = []
for buffer in buffers:
if torch.is_floating_point(buffer.data):
if average:
handle = dist.all_reduce(
buffer.data, op=dist.ReduceOp.SUM, async_op=True
)
else:
handle = dist.broadcast(buffer.data, src=0, async_op=True)
handles.append((buffer, handle))
for buffer, handle in handles:
handle.wait()
if average:
buffer.data /= world_size
def sync_grad(params):
"""
Simpler alternative to DistributedDataParallel, that doesn't rely
on any black magic. For simple models it can also be as fast.
Just call this on your model parameters after the call to backward!
"""
if not is_distributed():
return
handles = []
for p in params:
if p.grad is not None:
handle = dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=True)
handles.append((p, handle))
for p, handle in handles:
handle.wait()
p.grad.data /= world_size()
def average_metrics(metrics: Dict[str, float], count=1.0):
"""Average a dictionary of metrics across all workers, using the optional
`count` as unormalized weight.
"""
if not is_distributed():
return metrics
keys, values = zip(*metrics.items())
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
tensor *= count
all_reduce(tensor)
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
return dict(zip(keys, averaged))

View File

@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
from torch import Tensor, nn
from .core_vq import ResidualVectorQuantization
@dataclass
class QuantizedResult:
quantized: Tensor
codes: Tensor
bandwidth: Tensor # bandwidth in kb/s used, per batch item.
penalty: Optional[Tensor] = None
metrics: dict = field(default_factory=dict)
class ResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer.
Args:
dimension (int): Dimension of the codebooks.
n_q (int): Number of residual vector quantizers used.
bins (int): Codebook size.
decay (float): Decay for exponential moving average over the codebooks.
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
kmeans_iters (int): Number of iterations used for kmeans initialization.
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
n_q: int = 8,
bins: int = 1024,
decay: float = 0.99,
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
):
super().__init__()
self.n_q = n_q
self.dimension = dimension
self.bins = bins
self.decay = decay
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
self.threshold_ema_dead_code = threshold_ema_dead_code
self.vq = ResidualVectorQuantization(
dim=self.dimension,
codebook_size=self.bins,
num_quantizers=self.n_q,
decay=self.decay,
kmeans_init=self.kmeans_init,
kmeans_iters=self.kmeans_iters,
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
def forward(
self, x: Tensor, sample_rate: int, bandwidth: Optional[float] = None
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (Tensor): Input tensor.
sample_rate (int): Sample rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return quantized, codes, bw, torch.mean(commit_loss)
# return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
def get_num_quantizers_for_bandwidth(
self, sample_rate: int, bandwidth: Optional[float] = None
) -> int:
"""Return n_q based on specified target bandwidth."""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.n_q
if bandwidth and bandwidth > 0.0:
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
return n_q
def get_bandwidth_per_quantizer(self, sample_rate: int):
"""Return bandwidth per quantizer for a given input sample rate."""
return math.log2(self.bins) * sample_rate / 1000
def encode(
self,
x: Tensor,
sample_rate: int,
bandwidth: Optional[float] = None,
st: Optional[int] = None,
) -> Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
"""
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
st = st or 0
codes = self.vq.encode(x, n_q=n_q, st=st)
return codes
def decode(self, codes: Tensor) -> Tensor:
"""Decode the given codes to the quantized representation."""
quantized = self.vq.decode(codes)
return quantized

View File

@ -0,0 +1,902 @@
import argparse
import itertools
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from encodec import Encodec
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from utils import MetricsTracker, plot_feature, save_checkpoint
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=500,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lr", type=float, default=3.0e-4, help="The base learning rate."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=20,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
"""
params = AttributeDict(
{
# training params
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 50,
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 24000,
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
"lambda_wav": 100.0, # loss scaling coefficient for waveform loss
"lambda_feat": 1.0, # loss scaling coefficient for feat loss
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
"lambda_com": 1000.0, # loss scaling coefficient for commitment loss
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def get_model(params: AttributeDict) -> nn.Module:
"""Get the model based on the configuration."""
from discriminators import (
MultiPeriodDiscriminator,
MultiScaleDiscriminator,
MultiScaleSTFTDiscriminator,
)
from modules.seanet import SEANetDecoder, SEANetEncoder
from quantization import ResidualVectorQuantizer
generator_params = {
"generator_n_filters": 32,
"dimension": 512,
"ratios": [2, 2, 2, 4],
"target_bandwidths": [7.5, 15],
"bins": 1024,
}
discriminator_params = {
"stft_discriminator_n_filters": 32,
}
params.update(generator_params)
params.update(discriminator_params)
hop_length = np.prod(params.ratios)
n_q = int(
1000
* params.target_bandwidths[-1]
// (math.ceil(params.sample_rate / hop_length) * 10)
)
encoder = SEANetEncoder(
n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
)
decoder = SEANetDecoder(
n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios
)
quantizer = ResidualVectorQuantizer(
dimension=params.dimension, n_q=n_q, bins=params.bins
)
model = Encodec(
params=params,
sample_rate=params.sampling_rate,
target_bandwidths=params.target_bandwidths,
encoder=encoder,
quantizer=quantizer,
decoder=decoder,
multi_scale_discriminator=MultiScaleDiscriminator(),
multi_period_discriminator=MultiPeriodDiscriminator(),
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(),
)
return model
def prepare_input(
batch: dict,
device: torch.device,
):
"""Parse batch data"""
audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
features = batch["features"].to(device, memory_format=torch.contiguous_format)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
return audio, audio_lens, features, features_lens
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer_g: Optimizer,
optimizer_d: Optimizer,
scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model to be trained.
optimizer_g:
The optimizer for generator.
optimizer_d:
The optimizer for discriminator.
scheduler_g:
The learning rate scheduler for generator, we call step() every epoch.
scheduler_d:
The learning rate scheduler for discriminator, we call step() every epoch.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
_,
_,
) = prepare_input(batch, device)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
try:
with autocast(enabled=params.use_fp16):
# forward discriminator
loss_d, stats_d = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
return_sample=False,
forward_generator=False,
)
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
optimizer_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.step(optimizer_d)
with autocast(enabled=params.use_fp16):
# forward generator
loss_g, stats_g = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.step(optimizer_g)
scaler.update()
# summary stats
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.print_diagnostics and batch_idx == 5:
return
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_lr_g = max(scheduler_g.get_last_lr())
cur_lr_d = max(scheduler_d.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate_g", cur_lr_g, params.batch_idx_train
)
tb_writer.add_scalar(
"train/learning_rate_d", cur_lr_d, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
"train/speech_hat_",
speech_hat_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/speech_",
speech_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_image(
"train/mel_hat_",
plot_feature(mel_hat_),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
"train/mel_",
plot_feature(mel_),
params.batch_idx_train,
dataformats="HWC",
)
if (
params.batch_idx_train % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
"train/valdi_speech_hat",
speech_hat,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valdi_speech",
speech,
params.batch_idx_train,
params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations
tot_loss = MetricsTracker()
returned_sample = None
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
_,
_,
) = prepare_input(batch, device)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
# forward discriminator
loss_d, stats_d = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
return_sample=False,
forward_generator=False,
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# forward generator
loss_g, stats_g = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
forward_generator=True,
return_sample=batch_idx == 0,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# summary stats
tot_loss = tot_loss + loss_info
# infer for first batch:
if batch_idx == 0 and rank == 0:
speech_hat_, speech_, _, _ = stats_g["returned_sample"]
returned_sample = (speech_hat_, speech_)
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss, returned_sample
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
(
audio,
audio_lens,
_,
_,
) = prepare_input(batch, device)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
loss_d, stats_d = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
return_sample=False,
forward_generator=False,
)
optimizer_d.zero_grad()
loss_d.backward()
# for generator
with autocast(enabled=params.use_fp16):
loss_g, stats_g = model(
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
global_step=params.batch_idx_train,
return_sample=False,
)
optimizer_g.zero_grad()
loss_g.backward()
except Exception as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
vctk = VctkTtsDataModule(args)
train_cuts = vctk.train_cuts()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
encoder = model.encoder
decoder = model.decoder
quantizer = model.quantizer
multi_scale_discriminator = model.multi_scale_discriminator
multi_period_discriminator = model.multi_period_discriminator
multi_scale_stft_discriminator = model.multi_scale_stft_discriminator
num_param_e = sum([p.numel() for p in encoder.parameters()])
logging.info(f"Number of parameters in encoder: {num_param_e}")
num_param_d = sum([p.numel() for p in decoder.parameters()])
logging.info(f"Number of parameters in decoder: {num_param_d}")
num_param_q = sum([p.numel() for p in quantizer.parameters()])
logging.info(f"Number of parameters in quantizer: {num_param_q}")
num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()])
logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()])
logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
num_param_dstft = sum(
[p.numel() for p in multi_scale_stft_discriminator.parameters()]
)
logging.info(
f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}"
)
logging.info(
f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}"
)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW(
itertools.chain(
encoder.parameters(),
quantizer.parameters(),
decoder.parameters(),
),
lr=params.lr,
betas=(0.5, 0.9),
)
optimizer_d = torch.optim.AdamW(
itertools.chain(
multi_scale_stft_discriminator.parameters(),
multi_scale_discriminator.parameters(),
multi_period_discriminator.parameters(),
),
lr=params.lr,
betas=(0.5, 0.9),
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
if checkpoints is not None:
# load state_dict for optimizers
if "optimizer_g" in checkpoints:
logging.info("Loading optimizer_g state dict")
optimizer_g.load_state_dict(checkpoints["optimizer_g"])
if "optimizer_d" in checkpoints:
logging.info("Loading optimizer_d state dict")
optimizer_d.load_state_dict(checkpoints["optimizer_d"])
# load state_dict for schedulers
if "scheduler_g" in checkpoints:
logging.info("Loading scheduler_g state dict")
scheduler_g.load_state_dict(checkpoints["scheduler_g"])
if "scheduler_d" in checkpoints:
logging.info("Loading scheduler_d state dict")
scheduler_d.load_state_dict(checkpoints["scheduler_d"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
train_dl = vctk.train_dataloaders(train_cuts)
valid_cuts = vctk.valid_cuts()
valid_dl = vctk.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
# step per epoch
scheduler_g.step()
scheduler_d.step()
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
VctkTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../vctk/TTS/vits/utils.py