Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit dd82686a0f (parent cea0dbe7b1): init commit
egs/libritts/ASR/local/compute_fbank_libritts.py  (new executable file, 160 lines)
@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                    Zengwei Yao,)
#                2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the LibriTTS dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset",
        type=str,
        help="""Dataset parts to compute fbank. If None, all parts are used.""",
    )
    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=True,
        help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
    )
    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="""Sampling rate of the audio for computing fbank. The default
        value for LibriTTS is 24000; audio files will be resampled if a
        different sampling rate is provided.""",
    )

    return parser.parse_args()


def compute_fbank_libritts(
    dataset: Optional[str] = None,
    sampling_rate: int = 24000,
    perturb_speed: Optional[bool] = True,
):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(32, os.cpu_count())

    num_mel_bins = 80

    if dataset is None:
        dataset_parts = (
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        )
    else:
        dataset_parts = dataset.split(" ", -1)

    prefix = "libritts"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if sampling_rate != 24000:
                logging.info(f"Resampling audio to {sampling_rate}")
                cut_set = cut_set.resample(sampling_rate)
            if "train" in partition:
                if perturb_speed:
                    logging.info("Doing speed perturb")
                    cut_set = (
                        cut_set
                        + cut_set.perturb_speed(0.9)
                        + cut_set.perturb_speed(1.1)
                    )

            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))

    compute_fbank_libritts(
        dataset=args.dataset,
        sampling_rate=args.sampling_rate,
        perturb_speed=args.perturb_speed,
    )
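A minimal invocation sketch for the script above (assuming it is run from egs/libritts/ASR after the manifests exist in data/manifests; the flags are the ones defined in get_args()):

    ./local/compute_fbank_libritts.py \
      --dataset "dev-clean test-clean" \
      --sampling-rate 16000 \
      --perturb-speed false

Omitting --dataset processes all seven LibriTTS subsets; --dataset takes a space-separated list, matching dataset.split(" ", -1) above.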
egs/libritts/ASR/local/compute_fbank_musan.py  (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py
egs/libritts/ASR/local/compute_spectrogram_libritts.py  (new executable file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                    Zengwei Yao,)
#                2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes spectrogram features of the LibriTTS dataset.
It looks for manifests in the directory data/manifests.

The generated spectrogram features are saved in data/spectrogram.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    LilcomChunkyWriter,
    Spectrogram,
    SpectrogramConfig,
    load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_spectrogram_libritts():
    src_dir = Path("data/manifests")
    output_dir = Path("data/spectrogram")
    num_jobs = min(32, os.cpu_count())

    sampling_rate = 24000
    frame_length = 1024 / sampling_rate  # (in seconds)
    frame_shift = 256 / sampling_rate  # (in seconds)
    use_fft_mag = True

    prefix = "libritts"
    suffix = "jsonl.gz"
    partition = "all"

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
    ).resample(sampling_rate=sampling_rate)
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
    )

    config = SpectrogramConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=use_fft_mag,
    )
    extractor = Spectrogram(config)

    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{partition} already exists - skipping.")
            return
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_spectrogram_libritts()
egs/libritts/ASR/local/display_manifest_statistics.py  (new executable file, 341 lines)
@@ -0,0 +1,341 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp.             (authors: Zengwei Yao)
#           2024 The Chinese Univ. of HK  (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum duration
for removing short and long utterances during training.
"""


from lhotse import load_manifest_lazy


def main():
    paths = [
        "./data/fbank/libritts_cuts_train-clean-100.jsonl.gz",
        "./data/fbank/libritts_cuts_train-clean-360.jsonl.gz",
        "./data/fbank/libritts_cuts_train-other-500.jsonl.gz",
        "./data/fbank/libritts_cuts_dev-clean.jsonl.gz",
        "./data/fbank/libritts_cuts_dev-other.jsonl.gz",
        "./data/fbank/libritts_cuts_test-clean.jsonl.gz",
        "./data/fbank/libritts_cuts_test-other.jsonl.gz",
    ]
    for path in paths:
        cuts = load_manifest_lazy(path)
        cuts.describe()


if __name__ == "__main__":
    main()

"""
./data/fbank/libritts_cuts_train-clean-100.jsonl.gz statistics:
  Cuts count:                33236
  Total duration (hh:mm:ss): 53:47:18
  Duration (s): mean 5.8, std 4.6, min 0.2, 25% 2.4, 50% 4.5, 75% 7.9,
                99% 21.4, 99.5% 23.7, 99.9% 27.8, max 33.2
  Recordings / Features / Supervisions available: 33236
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         53:47:18  (100.00% of recording)
    Total speaking time duration  53:47:18  (100.00% of recording)
    Total silence duration        00:00:01  (0.00% of recording)

./data/fbank/libritts_cuts_train-clean-360.jsonl.gz statistics:
  Cuts count:                116500
  Total duration (hh:mm:ss): 191:17:42
  Duration (s): mean 5.9, std 4.6, min 0.1, 25% 2.4, 50% 4.6, 75% 8.1,
                99% 21.3, 99.5% 23.4, 99.9% 27.4, max 40.4
  Recordings / Features / Supervisions available: 116500
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         191:17:42  (100.00% of recording)
    Total speaking time duration  191:17:42  (100.00% of recording)
    Total silence duration        00:00:01   (0.00% of recording)

./data/fbank/libritts_cuts_train-other-500.jsonl.gz statistics:
  Cuts count:                205043
  Total duration (hh:mm:ss): 310:04:36
  Duration (s): mean 5.4, std 4.4, min 0.1, 25% 2.3, 50% 4.2, 75% 7.3,
                99% 20.6, 99.5% 22.8, 99.9% 27.4, max 43.9
  Recordings / Features / Supervisions available: 205043
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         310:04:36  (100.00% of recording)
    Total speaking time duration  310:04:36  (100.00% of recording)
    Total silence duration        00:00:01   (0.00% of recording)

./data/fbank/libritts_cuts_dev-clean.jsonl.gz statistics:
  Cuts count:                5736
  Total duration (hh:mm:ss): 08:58:13
  Duration (s): mean 5.6, std 4.3, min 0.3, 25% 2.4, 50% 4.4, 75% 7.8,
                99% 19.9, 99.5% 21.9, 99.9% 26.3, max 30.1
  Recordings / Features / Supervisions available: 5736
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         08:58:13  (100.00% of recording)
    Total speaking time duration  08:58:13  (100.00% of recording)
    Total silence duration        00:00:01  (0.00% of recording)

./data/fbank/libritts_cuts_dev-other.jsonl.gz statistics:
  Cuts count:                4613
  Total duration (hh:mm:ss): 06:25:52
  Duration (s): mean 5.0, std 4.1, min 0.3, 25% 2.2, 50% 3.8, 75% 6.5,
                99% 19.7, 99.5% 24.5, 99.9% 31.0, max 32.6
  Recordings / Features / Supervisions available: 4613
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         06:25:52  (100.00% of recording)
    Total speaking time duration  06:25:52  (100.00% of recording)
    Total silence duration        00:00:01  (0.00% of recording)

./data/fbank/libritts_cuts_test-clean.jsonl.gz statistics:
  Cuts count:                4837
  Total duration (hh:mm:ss): 08:34:09
  Duration (s): mean 6.4, std 5.1, min 0.3, 25% 2.4, 50% 4.8, 75% 8.9,
                99% 22.6, 99.5% 24.4, 99.9% 29.6, max 36.7
  Recordings / Features / Supervisions available: 4837
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         08:34:09  (100.00% of recording)
    Total speaking time duration  08:34:09  (100.00% of recording)
    Total silence duration        00:00:01  (0.00% of recording)

./data/fbank/libritts_cuts_test-other.jsonl.gz statistics:
  Cuts count:                5120
  Total duration (hh:mm:ss): 06:41:31
  Duration (s): mean 4.7, std 3.8, min 0.3, 25% 1.8, 50% 3.6, 75% 6.5,
                99% 17.8, 99.5% 20.4, 99.9% 23.8, max 27.3
  Recordings / Features / Supervisions available: 5120
  SUPERVISION custom fields - Speech duration statistics:
    Total speech duration         06:41:31  (100.00% of recording)
    Total speaking time duration  06:41:31  (100.00% of recording)
    Total silence duration        00:00:01  (0.00% of recording)
"""
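The docstring above suggests using these numbers to pick duration cut-offs before training. A hedged sketch of such a filter (the 1 s and 28 s thresholds are hypothetical values read off the min/99.9% rows above):

    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("./data/fbank/libritts_cuts_train-all-shuf.jsonl.gz")
    # Drop very short and very long utterances before training.
    cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 28.0)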
egs/libritts/ASR/local/validate_manifest.py  (new executable file, 71 lines)
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright 2022-2024 Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                    Zengwei Yao,)
#                2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:

- Single supervision per cut

We will add more checks later if needed.

Usage example:

    python3 ./local/validate_manifest.py \
        ./data/fbank/libritts_cuts_train-all-shuf.jsonl.gz

"""

import argparse
import logging
from pathlib import Path

from lhotse import CutSet, load_manifest
from lhotse.dataset.speech_recognition import validate_for_asr


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest",
        type=Path,
        help="Path to the manifest file",
    )

    return parser.parse_args()


def main():
    args = get_args()

    manifest = args.manifest
    logging.info(f"Validating {manifest}")

    assert manifest.is_file(), f"{manifest} does not exist"
    cut_set = load_manifest(manifest)
    assert isinstance(cut_set, CutSet)

    validate_for_asr(cut_set)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
egs/libritts/ASR/prepare.sh  (new executable file, 108 lines)
@@ -0,0 +1,108 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=0
stop_stage=100
sampling_rate=24000
perturb_speed=true

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/LibriTTS,
  # you can create a symlink
  #
  #   ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
  #
  if [ ! -d $dl_dir/LibriTTS ]; then
    lhotse download libritts $dl_dir
  fi

  # If you have pre-downloaded it to /path/to/musan,
  # you can create a symlink
  #
  #   ln -sfv /path/to/musan $dl_dir/musan
  #
  if [ ! -d $dl_dir/musan ]; then
    lhotse download musan $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare LibriTTS manifest"
  # We assume that you have downloaded the LibriTTS corpus
  # to $dl_dir/LibriTTS
  mkdir -p data/manifests
  if [ ! -e data/manifests/.libritts.done ]; then
    lhotse prepare libritts $dl_dir/LibriTTS data/manifests
    touch data/manifests/.libritts.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare musan manifest"
  # We assume that you have downloaded the musan corpus
  # to $dl_dir/musan
  if [ ! -f data/manifests/.musan_manifests.done ]; then
    log "It may take 6 minutes"
    mkdir -p data/manifests
    lhotse prepare musan $dl_dir/musan data/manifests
    touch data/manifests/.musan_manifests.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute Fbank for LibriTTS"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.libritts.done ]; then
    ./local/compute_fbank_libritts.py \
      --sampling-rate $sampling_rate \
      --perturb-speed $perturb_speed
    touch data/fbank/.libritts.done
  fi

  # Here we shuffle and combine the train-clean-100, train-clean-360 and
  # train-other-500 together to form the training set.
  if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
    cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
      <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
      <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
      shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
  fi

  if [ ! -e data/fbank/.libritts-validated.done ]; then
    log "Validating data/fbank for LibriTTS"
    ./local/validate_manifest.py \
      data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
    touch data/fbank/.libritts-validated.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for musan"
  if [ ! -f data/fbank/.musan.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_musan.py
    touch data/fbank/.musan.done
  fi
fi
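As a usage sketch (assuming the script is run from egs/libritts/ASR with lhotse and icefall installed), individual stages can be selected via the options parsed by shared/parse_options.sh:

    # manifests + LibriTTS fbank only, keeping the default 24 kHz audio
    ./prepare.sh --stage 1 --stop-stage 3

    # resume from the MUSAN fbank stage after a failure
    ./prepare.sh --stage 4 --stop-stage 4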
egs/libritts/ASR/shared  (new symbolic link)
@@ -0,0 +1 @@
../../../icefall/shared/
egs/libritts/ASR/zipformer/.gitignore  (new file, vendored)
@@ -0,0 +1 @@
swoosh.pdf
egs/libritts/ASR/zipformer/asr_datamodule.py  (new file, 459 lines)
@@ -0,0 +1,459 @@
# Copyright      2021  Piotr Żelasko
# Copyright      2024  The Chinese Univ. of HK  (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LibriTTSAsrDataModule:
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. libritts test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="""When enabled, use the full 960h LibriTTS training set.
            Otherwise, use the 100h subset.""",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with the training dataset.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        train = K2SpeechRecognitionDataset(
            input_strategy=eval(self.args.input_strategy)(),
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
            # Drop feats to be on the safe side.
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=(
                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                if self.args.on_the_fly_feats
                else eval(self.args.input_strategy)()
            ),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_clean_100_cuts(self) -> CutSet:
        logging.info("About to get train-clean-100 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz"
        )

    @lru_cache()
    def train_clean_360_cuts(self) -> CutSet:
        logging.info("About to get train-clean-360 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz"
        )

    @lru_cache()
    def train_other_500_cuts(self) -> CutSet:
        logging.info("About to get train-other-500 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
        )

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
        )
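A minimal sketch of how this data module is typically wired into a training or decoding script (hypothetical snippet; the actual train.py in this commit registers many more options):

    import argparse
    from asr_datamodule import LibriTTSAsrDataModule

    parser = argparse.ArgumentParser()
    LibriTTSAsrDataModule.add_arguments(parser)  # registers --manifest-dir, --max-duration, ...
    args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "200"])

    datamodule = LibriTTSAsrDataModule(args)
    train_cuts = (
        datamodule.train_all_shuf_cuts()
        if args.full_libri
        else datamodule.train_clean_100_cuts()
    )
    train_dl = datamodule.train_dataloaders(train_cuts)
    valid_dl = datamodule.valid_dataloaders(datamodule.dev_clean_cuts())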
egs/libritts/ASR/zipformer/attention_decoder.py  (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/attention_decoder.py
egs/libritts/ASR/zipformer/beam_search.py  (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
egs/libritts/ASR/zipformer/ctc_decode.py  (new executable file, 991 lines)
@@ -0,0 +1,991 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Liyong Guo,
|
||||
# Quandong Wang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) ctc-greedy-search
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-greedy-search
|
||||
|
||||
(2) ctc-decoding
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-decoding
|
||||
|
||||
(3) 1best
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
--decoding-method 1best
|
||||
|
||||
(4) nbest
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
--decoding-method nbest
|
||||
|
||||
(5) nbest-rescoring
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method nbest-rescoring
|
||||
|
||||
(6) whole-lattice-rescoring
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method whole-lattice-rescoring
|
||||
|
||||
(7) attention-decoder-rescoring-no-ngram
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--use-attention-decoder 1 \
|
||||
--max-duration 100 \
|
||||
--decoding-method attention-decoder-rescoring-no-ngram
|
||||
|
||||
(8) attention-decoder-rescoring-with-ngram
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--use-attention-decoder 1 \
|
||||
--max-duration 100 \
|
||||
--hlg-scale 0.6 \
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method attention-decoder-rescoring-with-ngram
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriTTSAsrDataModule
|
||||
from lhotse import set_caching_enabled
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import (
|
||||
ctc_greedy_search,
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder_no_ngram,
|
||||
rescore_with_attention_decoder_with_ngram,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="ctc-decoding",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (3) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (4) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (5) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (6) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
you have trained an RNN LM using ./rnn_lm/train.py
|
||||
- (7) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
- (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
|
||||
lattice, rescore them with the attention decoder.
|
||||
- (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
|
||||
rescored lattice, rescore them with the attention decoder.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hlg-scale",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="""The scale to be applied to `hlg.scores`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-scoring",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Skip scoring, but still save the ASR output (for eval sets).""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
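# Illustrative invocation of this decoding script (a sketch only; the script
# name, checkpoint epoch/avg and paths below are assumptions, adjust them to
# your setup):
#
#   ./zipformer/ctc_decode.py \
#     --epoch 30 --avg 9 \
#     --exp-dir ./zipformer/exp \
#     --decoding-method ctc-decoding \
#     --bpe-model data/lang_bpe_500/bpe.model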
|
||||
|
||||
|
||||
def get_decoding_params() -> AttributeDict:
|
||||
"""Parameters for decoding."""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10,
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
}
|
||||
)
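# (Clarifying note, not in the original code) These decoding-time defaults are
# merged into the training params in main() via params.update(get_decoding_params())
# and can still be overridden from the command line, because
# params.update(vars(args)) is applied last.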
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals the
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||
- params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||
rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
G:
|
||||
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict. Note: If it decodes to nothing, then return None.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.causal:
|
||||
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||
pad_len = 30
|
||||
feature_lens += pad_len
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, pad_len),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
|
||||
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
hyps = ctc_greedy_search(ctc_output, encoder_out_lens)
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(hyps)
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-greedy-search"
|
||||
return {key: hyps}
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
torch.div(
|
||||
supervisions["start_frame"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="floor",
|
||||
),
|
||||
torch.div(
|
||||
supervisions["num_frames"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="floor",
|
||||
),
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
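# (Illustrative note, values are examples only) supervision_segments has shape
# (num_utterances, 3); each row is (sequence_idx, start_frame, num_frames)
# measured after subsampling, e.g. with subsampling_factor=4 a 1000-frame
# utterance starting at frame 0 becomes the row [0, 0, 250].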
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a list-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps} # note: returns words
|
||||
|
||||
if params.decoding_method == "attention-decoder-rescoring-no-ngram":
|
||||
best_path_dict = rescore_with_attention_decoder_no_ngram(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
attention_decoder=model.attention_decoder,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
ans = dict()
|
||||
for a_scale_str, best_path in best_path_dict.items():
|
||||
# token_ids is a list-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
ans[a_scale_str] = hyps
|
||||
return ans
|
||||
|
||||
if params.decoding_method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is only slightly worse than that of rescored lattices.
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
nbest_scale=params.nbest_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method in ["1best", "nbest"]:
|
||||
if params.decoding_method == "1best":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no-rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}  # note: returns words, since word_table is used above
|
||||
|
||||
assert params.decoding_method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
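# (Clarifying comment, not in the original code) Each rescoring call below
# returns one best path per scale in lm_scale_list, keyed like "lm_scale_0.7";
# save_wer_results() later compares the WERs of all scales and reports the best.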
|
||||
|
||||
if params.decoding_method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.decoding_method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.decoding_method == "attention-decoder-rescoring-with-ngram":
|
||||
# The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder_with_ngram(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
attention_decoder=model.attention_decoder,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
||||
ans = dict()
|
||||
if best_path_dict is not None:
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
else:
|
||||
ans = None
|
||||
return ans
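# (Illustrative example, not part of the original code) A dict returned by
# decode_one_batch() for a 2-utterance batch might look like
#   {"lm_scale_0.7": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}
# i.e. one list of hypothesis words per utterance, keyed by the decoding
# setting; the words shown here are made up.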
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
G:
|
||||
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
|
||||
the cut ID, the reference transcript, and the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_asr_output(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
"""
|
||||
Save text produced by ASR.
|
||||
"""
|
||||
for key, results in results_dict.items():
|
||||
|
||||
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recogs_filename, texts=results)
|
||||
|
||||
logging.info(f"The transcripts are stored in {recogs_filename}")
|
||||
|
||||
|
||||
def save_wer_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
if params.decoding_method in (
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
"whole-lattice-rescoring",
|
||||
):
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
enable_log = True
|
||||
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w", encoding="utf8") as fd:
|
||||
wer = write_error_stats(
|
||||
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info(f"Wrote detailed error stats to {errs_filename}")
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
|
||||
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
with open(wer_filename, "w", encoding="utf8") as fd:
|
||||
print("settings\tWER", file=fd)
|
||||
for key, val in test_set_wers:
|
||||
print(f"{key}\t{val}", file=fd)
|
||||
|
||||
s = f"\nFor {test_set_name}, WER of different settings are:\n"
|
||||
note = f"\tbest for {test_set_name}"
|
||||
for key, val in test_set_wers:
|
||||
s += f"{key}\t{val}{note}\n"
|
||||
note = ""
|
||||
logging.info(s)
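# (Illustrative example) The wer-summary-<test-set>-<suffix>.txt file written
# above is a small tab-separated table; with made-up numbers it looks like:
#   settings        WER
#   lm_scale_0.7    2.85
#   lm_scale_0.9    2.91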
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriTTSAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
args.lm_dir = Path(args.lm_dir)
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
# enable AudioCache
|
||||
set_caching_enabled(True) # lhotse
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-greedy-search",
|
||||
"ctc-decoding",
|
||||
"1best",
|
||||
"nbest",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"nbest-oracle",
|
||||
"attention-decoder-rescoring-no-ngram",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"_chunk-{params.chunk_size}"
|
||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "_use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
params.vocab_size = num_classes
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = 0
|
||||
params.eos_id = 1
|
||||
params.sos_id = 1
|
||||
|
||||
if params.decoding_method in [
|
||||
"ctc-greedy-search",
|
||||
"ctc-decoding",
|
||||
"attention-decoder-rescoring-no-ngram",
|
||||
]:
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
HLG.scores *= params.hlg_scale
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.decoding_method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del G.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set G.properties to None
|
||||
G.__dict__["_properties"] = None
|
||||
G = k2.Fsa.from_fsas([G]).to(device)
|
||||
G = k2.arc_sort(G)
|
||||
# Save a dummy value so that it can be loaded in C++.
|
||||
# See https://github.com/pytorch/pytorch/issues/67902
|
||||
# for why we need to do this.
|
||||
G.dummy = 1
|
||||
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method in [
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
else:
|
||||
G = None
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriTTSAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
save_asr_output(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if not params.skip_scoring:
|
||||
save_wer_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1085  egs/libritts/ASR/zipformer/decode.py  (Executable file; diff suppressed because it is too large)
1     egs/libritts/ASR/zipformer/decode_stream.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/decode_stream.py)
1     egs/libritts/ASR/zipformer/encoder_interface.py  (Symbolic link -> ../../../librispeech/ASR/transducer_stateless/encoder_interface.py)
1     egs/libritts/ASR/zipformer/export-onnx-ctc.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/export-onnx-ctc.py)
1     egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py)
1     egs/libritts/ASR/zipformer/export-onnx-streaming.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/export-onnx-streaming.py)
1     egs/libritts/ASR/zipformer/export-onnx.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/export-onnx.py)
1     egs/libritts/ASR/zipformer/export.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/export.py)
1     egs/libritts/ASR/zipformer/generate_averaged_model.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/generate_averaged_model.py)
1     egs/libritts/ASR/zipformer/jit_pretrained.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/jit_pretrained.py)
1     egs/libritts/ASR/zipformer/jit_pretrained_ctc.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py)
1     egs/libritts/ASR/zipformer/jit_pretrained_streaming.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py)
1     egs/libritts/ASR/zipformer/joiner.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/joiner.py)
1     egs/libritts/ASR/zipformer/label_smoothing.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/label_smoothing.py)
1     egs/libritts/ASR/zipformer/model.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/model.py)
1     egs/libritts/ASR/zipformer/my_profile.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/my_profile.py)
1     egs/libritts/ASR/zipformer/onnx_check.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_check.py)
324   egs/libritts/ASR/zipformer/onnx_decode.py  (Executable file)
@ -0,0 +1,324 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX exported models and uses them to decode the test sets.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./zipformer/export-onnx.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--causal False
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
3. Run this file
|
||||
|
||||
./zipformer/onnx_decode.py \
|
||||
--exp-dir $repo/exp \
|
||||
--max-duration 600 \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriTTSAsrDataModule
|
||||
from k2 import SymbolTable
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
model: OnnxModel, token_table: SymbolTable, batch: dict
|
||||
) -> List[List[str]]:
|
||||
"""Decode one batch and return the result.
|
||||
Currently only greedy_search is supported.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The neural model.
|
||||
token_table:
|
||||
The token table.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
|
||||
|
||||
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
|
||||
)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += token_table[i]
|
||||
return text.replace("▁", " ").strip()
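# (Illustrative example) e.g. the word pieces "▁HELLO", "▁WOR", "LD" are
# concatenated to "▁HELLO▁WORLD", and replacing "▁" with a space plus
# stripping yields "HELLO WORLD"; the actual pieces depend on tokens.txt.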
|
||||
|
||||
hyps = [token_ids_to_words(h).split() for h in hyps]
|
||||
return hyps
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
model: nn.Module,
|
||||
token_table: SymbolTable,
|
||||
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
model:
|
||||
The neural model.
|
||||
token_table:
|
||||
The token table.
|
||||
|
||||
Returns:
|
||||
- A list of tuples. Each tuple contains three elements:
|
||||
- cut_id,
|
||||
- reference transcript,
|
||||
- predicted result.
|
||||
- The total duration (in seconds) of the dataset.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
log_interval = 10
|
||||
total_duration = 0
|
||||
|
||||
results = []
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
|
||||
|
||||
hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
|
||||
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results.extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
|
||||
return results, total_duration
|
||||
|
||||
|
||||
def save_results(
|
||||
res_dir: Path,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, List[str], List[str]]],
|
||||
):
|
||||
recog_path = res_dir / f"recogs-{test_set_name}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = res_dir / f"errs-{test_set_name}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("WER", file=f)
|
||||
print(wer, file=f)
|
||||
|
||||
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriTTSAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert (
|
||||
args.decoding_method == "greedy_search"
|
||||
), "Only supports greedy_search currently."
|
||||
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
|
||||
|
||||
setup_logger(f"{res_dir}/log-decode")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
token_table = SymbolTable.from_file(args.tokens)
|
||||
|
||||
logging.info(vars(args))
|
||||
|
||||
logging.info("About to create model")
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriTTSAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
start_time = time.time()
|
||||
results, total_duration = decode_dataset(
|
||||
dl=test_dl, model=model, token_table=token_table
|
||||
)
|
||||
end_time = time.time()
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
|
||||
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
logging.info(f"Wave duration: {total_duration:.3f} s")
|
||||
logging.info(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
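# (Illustrative arithmetic) e.g. decoding 2400 s of audio in 120 s of wall
# time gives RTF = 120 / 2400 = 0.05; a smaller RTF means faster than real time.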
|
||||
|
||||
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1     egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py)
1     egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py)
1     egs/libritts/ASR/zipformer/optim.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/optim.py)
1     egs/libritts/ASR/zipformer/pretrained.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/pretrained.py)
1     egs/libritts/ASR/zipformer/pretrained_ctc.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/pretrained_ctc.py)
1     egs/libritts/ASR/zipformer/scaling.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/scaling.py)
1     egs/libritts/ASR/zipformer/scaling_converter.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/scaling_converter.py)
1     egs/libritts/ASR/zipformer/streaming_beam_search.py  (Symbolic link -> ../../../librispeech/ASR/zipformer/streaming_beam_search.py)
904   egs/libritts/ASR/zipformer/streaming_decode.py  (Executable file)
@ -0,0 +1,904 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
|
||||
# Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./zipformer/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--causal 1 \
|
||||
--chunk-size 32 \
|
||||
--left-context-frames 256 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--decoding-method greedy_search \
|
||||
--num-decode-streams 2000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from asr_datamodule import LibriTTSAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet, set_caching_enabled
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
make_pad_mask,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
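# LOG_EPS is about -23.03; it is used below as the padding value for log-mel
# feature frames so that padded frames carry (almost) no energy.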
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--label",
|
||||
type=str,
|
||||
default="",
|
||||
help="""Extra label of the decoding run.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Supported decoding methods are:
|
||||
greedy_search
|
||||
modified_beam_search
|
||||
fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=32,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-scoring",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Skip scoring, but still save the ASR output (for eval sets).""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_init_states(
|
||||
model: nn.Module,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
|
||||
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
|
||||
states[-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
"""
|
||||
states = model.encoder.get_init_states(batch_size, device)
|
||||
|
||||
embed_states = model.encoder_embed.get_init_states(batch_size, device)
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
states.append(processed_lens)
|
||||
|
||||
return states
|
||||
|
||||
|
||||
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
|
||||
"""Stack list of zipformer states that correspond to separate utterances
|
||||
into a single zipformer state, so that it can be used as an input for
|
||||
zipformer when those utterances are formed into a batch.
|
||||
|
||||
Args:
|
||||
state_list:
|
||||
Each element in state_list corresponds to the internal state
|
||||
of the zipformer model for a single utterance. For element-n,
|
||||
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
|
||||
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
|
||||
cached_val2, cached_conv1, cached_conv2).
|
||||
state_list[n][-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
state_list[n][-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
|
||||
Note:
|
||||
It is the inverse of :func:`unstack_states`.
|
||||
"""
|
||||
batch_size = len(state_list)
|
||||
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
|
||||
tot_num_layers = (len(state_list[0]) - 2) // 6
|
||||
|
||||
batch_states = []
|
||||
for layer in range(tot_num_layers):
|
||||
layer_offset = layer * 6
|
||||
# cached_key: (left_context_len, batch_size, key_dim)
|
||||
cached_key = torch.cat(
|
||||
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
||||
cached_nonlin_attn = torch.cat(
|
||||
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_val1: (left_context_len, batch_size, value_dim)
|
||||
cached_val1 = torch.cat(
|
||||
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_val2: (left_context_len, batch_size, value_dim)
|
||||
cached_val2 = torch.cat(
|
||||
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_conv1: (#batch, channels, left_pad)
|
||||
cached_conv1 = torch.cat(
|
||||
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
|
||||
)
|
||||
# cached_conv2: (#batch, channels, left_pad)
|
||||
cached_conv2 = torch.cat(
|
||||
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
|
||||
)
|
||||
batch_states += [
|
||||
cached_key,
|
||||
cached_nonlin_attn,
|
||||
cached_val1,
|
||||
cached_val2,
|
||||
cached_conv1,
|
||||
cached_conv2,
|
||||
]
|
||||
|
||||
cached_embed_left_pad = torch.cat(
|
||||
[state_list[i][-2] for i in range(batch_size)], dim=0
|
||||
)
|
||||
batch_states.append(cached_embed_left_pad)
|
||||
|
||||
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
|
||||
batch_states.append(processed_lens)
|
||||
|
||||
return batch_states
|
||||
|
||||
|
||||
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
|
||||
"""Unstack the zipformer state corresponding to a batch of utterances
|
||||
into a list of states, where the i-th entry is the state from the i-th
|
||||
utterance in the batch.
|
||||
|
||||
Note:
|
||||
It is the inverse of :func:`stack_states`.
|
||||
|
||||
Args:
|
||||
batch_states: A list of cached tensors of all encoder layers. For layer-i,
|
||||
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
|
||||
cached_conv1, cached_conv2).
|
||||
batch_states[-2] is the cached left padding for the ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
|
||||
Returns:
|
||||
state_list: A list of lists. Each element in state_list corresponds to the internal state
|
||||
of the zipformer model for a single utterance.
|
||||
"""
|
||||
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
|
||||
tot_num_layers = (len(batch_states) - 2) // 6
|
||||
|
||||
processed_lens = batch_states[-1]
|
||||
batch_size = processed_lens.shape[0]
|
||||
|
||||
state_list = [[] for _ in range(batch_size)]
|
||||
|
||||
for layer in range(tot_num_layers):
|
||||
layer_offset = layer * 6
|
||||
# cached_key: (left_context_len, batch_size, key_dim)
|
||||
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
|
||||
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
||||
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
|
||||
chunks=batch_size, dim=1
|
||||
)
|
||||
# cached_val1: (left_context_len, batch_size, value_dim)
|
||||
cached_val1_list = batch_states[layer_offset + 2].chunk(
|
||||
chunks=batch_size, dim=1
|
||||
)
|
||||
# cached_val2: (left_context_len, batch_size, value_dim)
|
||||
cached_val2_list = batch_states[layer_offset + 3].chunk(
|
||||
chunks=batch_size, dim=1
|
||||
)
|
||||
# cached_conv1: (#batch, channels, left_pad)
|
||||
cached_conv1_list = batch_states[layer_offset + 4].chunk(
|
||||
chunks=batch_size, dim=0
|
||||
)
|
||||
# cached_conv2: (#batch, channels, left_pad)
|
||||
cached_conv2_list = batch_states[layer_offset + 5].chunk(
|
||||
chunks=batch_size, dim=0
|
||||
)
|
||||
for i in range(batch_size):
|
||||
state_list[i] += [
|
||||
cached_key_list[i],
|
||||
cached_nonlin_attn_list[i],
|
||||
cached_val1_list[i],
|
||||
cached_val2_list[i],
|
||||
cached_conv1_list[i],
|
||||
cached_conv2_list[i],
|
||||
]
|
||||
|
||||
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
|
||||
for i in range(batch_size):
|
||||
state_list[i].append(cached_embed_left_pad_list[i])
|
||||
|
||||
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
|
||||
for i in range(batch_size):
|
||||
state_list[i].append(processed_lens_list[i])
|
||||
|
||||
return state_list
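# (Illustrative sketch, not in the original code) stack_states() and
# unstack_states() are inverses, so for per-utterance states s0 and s1:
#   batch = stack_states([s0, s1])
#   s0_again, s1_again = unstack_states(batch)
# recovers tensors with the same shapes and contents as s0 and s1.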
|
||||
|
||||
|
||||
def streaming_forward(
|
||||
features: Tensor,
|
||||
feature_lens: Tensor,
|
||||
model: nn.Module,
|
||||
states: List[Tensor],
|
||||
chunk_size: int,
|
||||
left_context_len: int,
|
||||
) -> Tuple[Tensor, Tensor, List[Tensor]]:
|
||||
"""
|
||||
Returns encoder outputs, output lengths, and updated states.
|
||||
"""
|
||||
cached_embed_left_pad = states[-2]
|
||||
(
|
||||
x,
|
||||
x_lens,
|
||||
new_cached_embed_left_pad,
|
||||
) = model.encoder_embed.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
)
|
||||
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
x.size(0), left_context_len
|
||||
)
|
||||
processed_lens = states[-1] # (batch,)
|
||||
# (batch, left_context_size)
|
||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
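# (Worked example, values made up) with left_context_len = 4 and
# processed_lens = [2], arange gives [0, 1, 2, 3]; the comparison yields
# [False, False, True, True] and flip(1) turns it into
# [True, True, False, False], i.e. two of the four cached left-context
# positions are masked out as padding because only two frames have been
# processed so far.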
|
||||
# Update processed lengths
|
||||
new_processed_lens = processed_lens + x_lens
|
||||
|
||||
# (batch, left_context_size + chunk_size)
|
||||
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
encoder_states = states[:-2]
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_encoder_states,
|
||||
) = model.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=encoder_states,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
new_states = new_encoder_states + [
|
||||
new_cached_embed_left_pad,
|
||||
new_processed_lens,
|
||||
]
|
||||
return encoder_out, encoder_out_lens, new_states
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
decode_streams: List[DecodeStream],
|
||||
) -> List[int]:
|
||||
"""Decode one chunk frames of features for each decode_streams and
|
||||
return the indexes of finished streams in a List.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decode_streams:
|
||||
A list of DecodeStream, each belonging to an utterance.
|
||||
Returns:
|
||||
Return a list containing the indexes of the DecodeStreams that are finished.
|
||||
"""
|
||||
device = model.device
|
||||
chunk_size = int(params.chunk_size)
|
||||
left_context_len = int(params.left_context_frames)
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
processed_lens = [] # Used in fast-beam-search
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# Make sure the length after encoder_embed is at least 1.
|
||||
# The encoder_embed subsamples features: T -> (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
tail_length = chunk_size * 2 + 7 + 2 * 3
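# (Clarifying arithmetic) e.g. with chunk_size = 32 the chunk covers 64
# input frames, plus 7 frames consumed by the (T - 7) // 2 subsampling and
# 2 * 3 frames of ConvNeXt right padding, giving tail_length = 77.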
|
||||
if features.size(1) < tail_length:
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
states = stack_states(states)
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = streaming_forward(
|
||||
features=features,
|
||||
feature_lens=feature_lens,
|
||||
model=model,
|
||||
states=states,
|
||||
chunk_size=chunk_size,
|
||||
left_context_len=left_context_len,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
processed_lens=processed_lens,
|
||||
streams=decode_streams,
|
||||
beam=params.beam,
|
||||
max_states=params.max_states,
|
||||
max_contexts=params.max_contexts,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
|
||||
states = unstack_states(new_states)
|
||||
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = states[i]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
cuts:
|
||||
Lhotse CutSet containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
opts = FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 100
|
||||
|
||||
decode_results = []
|
||||
# Contains the decode streams currently running.
|
||||
decode_streams = []
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
initial_states = get_init_states(model=model, batch_size=1, device=device)
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
)
|
||||
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
# - this is to avoid sending [-32k,+32k] signal in...
|
||||
# - some lhotse AudioTransform classes can make the signal
|
||||
# be out of range [-1, 1], hence the tolerance 10
|
||||
assert (
|
||||
np.abs(audio).max() <= 10
|
||||
), "Should be normalized to [-1, 1], 10 for tolerance..."
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature, tail_pad_len=30)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if num % log_interval == 0:
|
||||
logging.info(f"Cuts processed until now is {num}.")
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_asr_output(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
"""
|
||||
Save text produced by ASR.
|
||||
"""
|
||||
for key, results in results_dict.items():
|
||||
recogs_filename = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recogs_filename, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recogs_filename}")
|
||||
|
||||
|
||||
def save_wer_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
"""
|
||||
Save WER and per-utterance word alignments.
|
||||
"""
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w", encoding="utf8") as fd:
|
||||
wer = write_error_stats(
|
||||
fd, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info(f"Wrote detailed error stats to {errs_filename}")
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
|
||||
wer_filename = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(wer_filename, "w", encoding="utf8") as fd:
|
||||
print("settings\tWER", file=fd)
|
||||
for key, val in test_set_wers:
|
||||
print(f"{key}\t{val}", file=fd)
|
||||
|
||||
s = f"\nFor {test_set_name}, WER of different settings are:\n"
|
||||
note = f"\tbest for {test_set_name}"
|
||||
for key, val in test_set_wers:
|
||||
s += f"{key}\t{val}{note}\n"
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriTTSAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
# enable AudioCache
|
||||
set_caching_enabled(True) # lhotse
|
||||
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
assert params.causal, params.causal
|
||||
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"_chunk-{params.chunk_size}"
|
||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"_beam-{params.beam}"
|
||||
params.suffix += f"_max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"_max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
if params.label:
|
||||
params.suffix += f"-{params.label}"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
libritts = LibriTTSAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = libritts.test_clean_cuts()
|
||||
test_other_cuts = libritts.test_other_cuts()
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_cuts = [test_clean_cuts, test_other_cuts]
|
||||
|
||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||
results_dict = decode_dataset(
|
||||
cuts=test_cut,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_asr_output(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if not params.skip_scoring:
|
||||
save_wer_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/libritts/ASR/zipformer/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/subsampling.py
|
1511
egs/libritts/ASR/zipformer/train.py
Executable file
File diff suppressed because it is too large
1
egs/libritts/ASR/zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/zipformer.py
|
161
egs/libritts/CODEC/encodec/binary.py
Normal file
@ -0,0 +1,161 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
from typing import IO, Any, List, Optional
|
||||
|
||||
# format is `ECDC` magic code, followed by the header size as uint32.
|
||||
# Then a uint8 indicates the protocol version (0).
|
||||
# The header is then provided as json and should contain all required
|
||||
# information for decoding. A raw stream of bytes is then provided
|
||||
# and should be interpretable using the json header.
|
||||
_encodec_header_struct = struct.Struct("!4sBI")
|
||||
_ENCODEC_MAGIC = b"ECDC"
|
||||
|
||||
|
||||
def write_ecdc_header(fo: IO[bytes], metadata: Any):
|
||||
meta_dumped = json.dumps(metadata).encode("utf-8")
|
||||
version = 0
|
||||
header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped))
|
||||
fo.write(header)
|
||||
fo.write(meta_dumped)
|
||||
fo.flush()
|
||||
|
||||
|
||||
def _read_exactly(fo: IO[bytes], size: int) -> bytes:
|
||||
buf = b""
|
||||
while len(buf) < size:
|
||||
new_buf = fo.read(size)
|
||||
if not new_buf:
|
||||
raise EOFError(
|
||||
"Impossible to read enough data from the stream, "
|
||||
f"{size} bytes remaining."
|
||||
)
|
||||
buf += new_buf
|
||||
size -= len(new_buf)
|
||||
return buf
|
||||
|
||||
|
||||
def read_ecdc_header(fo: IO[bytes]):
|
||||
header_bytes = _read_exactly(fo, _encodec_header_struct.size)
|
||||
magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
|
||||
if magic != _ENCODEC_MAGIC:
|
||||
raise ValueError("File is not in ECDC format.")
|
||||
if version != 0:
|
||||
raise ValueError("Version not supported.")
|
||||
meta_bytes = _read_exactly(fo, meta_size)
|
||||
return json.loads(meta_bytes.decode("utf-8"))
|
||||
|
||||
|
||||
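# A minimal round-trip sketch for the two helpers above (an illustrative
# addition, not part of the original file); the metadata fields are hypothetical.
def _example_ecdc_header_roundtrip():
    metadata = {"model": "encodec_24khz", "num_codebooks": 8}  # hypothetical fields
    buf = io.BytesIO()
    write_ecdc_header(buf, metadata)
    buf.seek(0)
    assert read_ecdc_header(buf) == metadata

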
class BitPacker:
|
||||
"""Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
|
||||
Note that for some bandwidth (1.5, 3), the codebook representation
|
||||
will not cover an integer number of bytes.
|
||||
|
||||
Args:
|
||||
bits (int): number of bits per value that will be pushed.
|
||||
fo (IO[bytes]): file-object to push the bytes to.
|
||||
"""
|
||||
|
||||
def __init__(self, bits: int, fo: IO[bytes]):
|
||||
self._current_value = 0
|
||||
self._current_bits = 0
|
||||
self.bits = bits
|
||||
self.fo = fo
|
||||
|
||||
def push(self, value: int):
|
||||
"""Push a new value to the stream. This will immediately
|
||||
write as many uint8 as possible to the underlying file-object."""
|
||||
self._current_value += value << self._current_bits
|
||||
self._current_bits += self.bits
|
||||
while self._current_bits >= 8:
|
||||
lower_8bits = self._current_value & 0xFF
|
||||
self._current_bits -= 8
|
||||
self._current_value >>= 8
|
||||
self.fo.write(bytes([lower_8bits]))
|
||||
|
||||
def flush(self):
|
||||
"""Flushes the remaining partial uint8, call this at the end
|
||||
of the stream to encode."""
|
||||
if self._current_bits:
|
||||
self.fo.write(bytes([self._current_value]))
|
||||
self._current_value = 0
|
||||
self._current_bits = 0
|
||||
self.fo.flush()
|
||||
|
||||
|
||||
class BitUnpacker:
|
||||
"""BitUnpacker does the opposite of `BitPacker`.
|
||||
|
||||
Args:
|
||||
bits (int): number of bits of the values to decode.
|
||||
fo (IO[bytes]): file-object to pull the bytes from.
|
||||
"""
|
||||
|
||||
def __init__(self, bits: int, fo: IO[bytes]):
|
||||
self.bits = bits
|
||||
self.fo = fo
|
||||
self._mask = (1 << bits) - 1
|
||||
self._current_value = 0
|
||||
self._current_bits = 0
|
||||
|
||||
def pull(self) -> Optional[int]:
|
||||
"""
|
||||
Pull a single value from the stream, potentially reading some
|
||||
extra bytes from the underlying file-object.
|
||||
Returns `None` when reaching the end of the stream.
|
||||
"""
|
||||
while self._current_bits < self.bits:
|
||||
buf = self.fo.read(1)
|
||||
if not buf:
|
||||
return None
|
||||
character = buf[0]
|
||||
self._current_value += character << self._current_bits
|
||||
self._current_bits += 8
|
||||
|
||||
out = self._current_value & self._mask
|
||||
self._current_value >>= self.bits
|
||||
self._current_bits -= self.bits
|
||||
return out
|
||||
|
||||
|
||||
def test():
|
||||
import torch
|
||||
|
||||
torch.manual_seed(1234)
|
||||
for rep in range(4):
|
||||
length: int = torch.randint(10, 2_000, (1,)).item()
|
||||
bits: int = torch.randint(1, 16, (1,)).item()
|
||||
tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
|
||||
rebuilt: List[int] = []
|
||||
buf = io.BytesIO()
|
||||
packer = BitPacker(bits, buf)
|
||||
for token in tokens:
|
||||
packer.push(token)
|
||||
packer.flush()
|
||||
buf.seek(0)
|
||||
unpacker = BitUnpacker(bits, buf)
|
||||
while True:
|
||||
value = unpacker.pull()
|
||||
if value is None:
|
||||
break
|
||||
rebuilt.append(value)
|
||||
assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
|
||||
# The flushing mechanism might lead to "ghost" values at the end of the stream.
|
||||
assert len(rebuilt) <= len(tokens) + 8 // bits, (
|
||||
len(rebuilt),
|
||||
len(tokens),
|
||||
bits,
|
||||
)
|
||||
for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
|
||||
assert a == b, (idx, a, b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
271
egs/libritts/CODEC/encodec/codec_datamodule.py
Normal file
@ -0,0 +1,271 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||
# Zengwei Yao,
|
||||
# Zengrui Jin,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriTTSCodecDataModule:
|
||||
"""
|
||||
DataModule for codec experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders.
|
||||
|
||||
It contains all the common data pipeline modules used in codec
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
|
||||
This class should be derived for specific corpora used in codec tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="Codec data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/spectrogram"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=False,
|
||||
return_spk_ids=False,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create valid dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.info("About to create test dataset")
|
||||
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
return_spk_ids=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
test_sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=test_sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get validation cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
|
117
egs/libritts/CODEC/encodec/discriminators.py
Normal file
@ -0,0 +1,117 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT
|
||||
from torch.nn import AvgPool1d
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
]
|
||||
)
|
||||
self.meanpools = nn.ModuleList(
|
||||
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
y = self.meanpools[i - 1](y)
|
||||
y_hat = self.meanpools[i - 1](y_hat)
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class MultiScaleSTFTDiscriminator(nn.Module):
|
||||
"""Multi-Scale STFT (MS-STFT) discriminator.
|
||||
Args:
|
||||
filters (int): Number of filters in convolutions
|
||||
in_channels (int): Number of input channels. Default: 1
|
||||
out_channels (int): Number of output channels. Default: 1
|
||||
n_ffts (Sequence[int]): Size of FFT for each scale
|
||||
hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
|
||||
win_lengths (Sequence[int]): Window size for each scale
|
||||
**kwargs: additional args for DiscriminatorSTFT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filters: int,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
n_ffts: List[int] = [1024, 2048, 512, 256, 128],
|
||||
hop_lengths: List[int] = [256, 512, 128, 64, 32],
|
||||
win_lengths: List[int] = [1024, 2048, 512, 256, 128],
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorSTFT(
|
||||
filters,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
n_fft=n_ffts[i],
|
||||
win_length=win_lengths[i],
|
||||
hop_length=hop_lengths[i],
|
||||
**kwargs
|
||||
)
|
||||
for i in range(len(n_ffts))
|
||||
]
|
||||
)
|
||||
self.num_discriminators = len(self.discriminators)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
logits = []
|
||||
fmaps = []
|
||||
for disc in self.discriminators:
|
||||
logit, fmap = disc(x)
|
||||
logits.append(logit)
|
||||
fmaps.append(fmap)
|
||||
return logits, fmaps
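

# Minimal usage sketch (illustrative, not part of the original file), assuming a
# batch of mono waveforms:
#
#   msstftd = MultiScaleSTFTDiscriminator(filters=32)
#   wav = torch.randn(4, 1, 24000)    # (B, C, T)
#   logits, fmaps = msstftd(wav)      # one logit / feature-map list per STFT scale
#   assert len(logits) == msstftd.num_discriminators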
|
261
egs/libritts/CODEC/encodec/encodec.py
Normal file
@ -0,0 +1,261 @@
|
||||
import math
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loss import loss_dis, loss_g
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
class Encodec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate: int,
|
||||
target_bandwidths: List[float],
|
||||
params: dict,
|
||||
encoder: nn.Module,
|
||||
quantizer: nn.Module,
|
||||
decoder: nn.Module,
|
||||
multi_scale_discriminator: nn.Module,
|
||||
multi_period_discriminator: nn.Module,
|
||||
multi_scale_stft_discriminator: nn.Module,
|
||||
cache_generator_outputs: bool = True,
|
||||
):
|
||||
super(Encodec, self).__init__()
|
||||
|
||||
self.params = params
|
||||
|
||||
# setup the generator
|
||||
self.sample_rate = sample_rate
|
||||
self.encoder = encoder
|
||||
self.quantizer = quantizer
|
||||
self.decoder = decoder
|
||||
|
||||
self.ratios = encoder.ratios
|
||||
self.hop_length = np.prod(self.ratios)
|
||||
self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios))
|
||||
self.target_bandwidths = target_bandwidths
|
||||
|
||||
# discriminators
|
||||
self.multi_scale_discriminator = multi_scale_discriminator
|
||||
self.multi_period_discriminator = multi_period_discriminator
|
||||
self.multi_scale_stft_discriminator = multi_scale_stft_discriminator
|
||||
|
||||
# cache
|
||||
self.cache_generator_outputs = cache_generator_outputs
|
||||
self._cache = None
|
||||
|
||||
def _forward_generator(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
global_step: int,
|
||||
return_sample: bool = False,
|
||||
):
|
||||
"""Perform generator forward.
|
||||
|
||||
Args:
|
||||
speech (Tensor): Speech waveform tensor (B, T_wav).
|
||||
speech_lengths (Tensor): Speech length tensor (B,).
|
||||
global_step (int): Global step.
|
||||
return_sample (bool): Return the generator output.
|
||||
|
||||
Returns:
|
||||
* loss (Tensor): Loss scalar tensor.
|
||||
* stats (Dict[str, float]): Statistics to be monitored.
|
||||
"""
|
||||
# setup
|
||||
speech = speech.unsqueeze(1)
|
||||
|
||||
# calculate generator outputs
|
||||
reuse_cache = True
|
||||
if not self.cache_generator_outputs or self._cache is None:
|
||||
reuse_cache = False
|
||||
e = self.encoder(speech)
|
||||
bw = random.choice(self.target_bandwidths)
|
||||
quantized, codes, bandwidth, commit_loss = self.quantizer(
|
||||
e, self.frame_rate, bw
|
||||
)
|
||||
speech_hat = self.decoder(quantized)
|
||||
else:
|
||||
speech_hat = self._cache
|
||||
|
||||
# store cache
|
||||
if self.training and self.cache_generator_outputs and not reuse_cache:
|
||||
self._cache = speech_hat
|
||||
|
||||
# calculate discriminator outputs
|
||||
y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous())
|
||||
with torch.no_grad():
|
||||
# do not store discriminator gradient in generator turn
|
||||
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous(),
|
||||
)
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous(),
|
||||
)
|
||||
|
||||
# calculate losses
|
||||
with autocast(enabled=False):
|
||||
loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
|
||||
commit_loss,
|
||||
speech,
|
||||
speech_hat,
|
||||
fmap,
|
||||
fmap_hat,
|
||||
y,
|
||||
y_hat,
|
||||
global_step,
|
||||
y_p,
|
||||
y_p_hat,
|
||||
y_s,
|
||||
y_s_hat,
|
||||
fmap_p,
|
||||
fmap_p_hat,
|
||||
fmap_s,
|
||||
fmap_s_hat,
|
||||
args=self.params,
|
||||
)
|
||||
|
||||
stats = dict(
|
||||
generator_loss=loss.item(),
|
||||
generator_reconstruction_loss=rec_loss.item(),
|
||||
generator_feature_loss=feat_loss.item(),
|
||||
generator_adv_loss=adv_loss.item(),
|
||||
generator_commit_loss=commit_loss.item(),
|
||||
d_weight=d_weight.item(),
|
||||
)
|
||||
|
||||
if return_sample:
|
||||
stats["returned_sample"] = (
|
||||
speech_hat[0].data.cpu().numpy(),
|
||||
speech[0].data.cpu().numpy(),
|
||||
fmap_hat[0][0].data.cpu().numpy(),
|
||||
fmap[0][0].data.cpu().numpy(),
|
||||
)
|
||||
|
||||
# reset cache
|
||||
if reuse_cache or not self.training:
|
||||
self._cache = None
|
||||
|
||||
return loss, stats
|
||||
|
||||
def _forward_discriminator(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
global_step: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
speech (Tensor): Speech waveform tensor (B, T_wav).
|
||||
speech_lengths (Tensor): Speech length tensor (B,).
|
||||
global_step (int): Global step.
|
||||
|
||||
Returns:
|
||||
* loss (Tensor): Loss scalar tensor.
|
||||
* stats (Dict[str, float]): Statistics to be monitored.
|
||||
"""
|
||||
# setup
|
||||
speech = speech.unsqueeze(1)
|
||||
|
||||
# calculate generator outputs
|
||||
reuse_cache = True
|
||||
if not self.cache_generator_outputs or self._cache is None:
|
||||
reuse_cache = False
|
||||
e = self.encoder(speech)
|
||||
bw = random.choice(self.target_bandwidths)
|
||||
quantized, codes, bandwidth, commit_loss = self.quantizer(
|
||||
e, self.frame_rate, bw
|
||||
)
|
||||
speech_hat = self.decoder(quantized)
|
||||
else:
|
||||
speech_hat = self._cache
|
||||
|
||||
# store cache
|
||||
if self.training and self.cache_generator_outputs and not reuse_cache:
|
||||
self._cache = speech_hat
|
||||
|
||||
# calculate discriminator outputs
|
||||
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
|
||||
y_hat, fmap_hat = self.multi_scale_stft_discriminator(
|
||||
speech_hat.contiguous().detach()
|
||||
)
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
# calculate losses
|
||||
with autocast(enabled=False):
|
||||
loss = loss_dis(
|
||||
y,
|
||||
y_hat,
|
||||
fmap,
|
||||
fmap_hat,
|
||||
y_p,
|
||||
y_p_hat,
|
||||
fmap_p,
|
||||
fmap_p_hat,
|
||||
y_s,
|
||||
y_s_hat,
|
||||
fmap_s,
|
||||
fmap_s_hat,
|
||||
global_step,
|
||||
args=self.params,
|
||||
)
|
||||
stats = dict(
|
||||
discriminator_loss=loss.item(),
|
||||
)
|
||||
|
||||
# reset cache
|
||||
if reuse_cache or not self.training:
|
||||
self._cache = None
|
||||
|
||||
return loss, stats
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
global_step: int,
|
||||
return_sample: bool,
|
||||
forward_generator: bool,
|
||||
):
|
||||
if forward_generator:
|
||||
return self._forward_generator(
|
||||
speech=speech,
|
||||
speech_lengths=speech_lengths,
|
||||
global_step=global_step,
|
||||
return_sample=return_sample,
|
||||
)
|
||||
else:
|
||||
return self._forward_discriminator(
|
||||
speech=speech,
|
||||
speech_lengths=speech_lengths,
|
||||
global_step=global_step,
|
||||
)
|
||||
|
||||
def encode(self, x, target_bw=None, st=None):
|
||||
e = self.encoder(x)
|
||||
if target_bw is None:
|
||||
bw = self.target_bandwidths[-1]
|
||||
else:
|
||||
bw = target_bw
|
||||
if st is None:
|
||||
st = 0
|
||||
codes = self.quantizer.encode(e, self.frame_rate, bw, st)
|
||||
return codes
|
||||
|
||||
def decode(self, codes):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
o = self.decoder(quantized)
|
||||
return o
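
    # Illustrative usage sketch (assuming `model` is a fully constructed Encodec
    # instance and `wav` is a (B, 1, T) waveform at `model.sample_rate`):
    #
    #   codes = model.encode(wav, target_bw=6.0)  # quantized codebook indices
    #   wav_hat = model.decode(codes)             # reconstructed waveform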
|
298
egs/libritts/CODEC/encodec/loss.py
Normal file
@ -0,0 +1,298 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchaudio.transforms import MelSpectrogram
|
||||
|
||||
|
||||
def adversarial_g_loss(y_disc_gen):
|
||||
"""Hinge loss"""
|
||||
loss = 0.0
|
||||
for i in range(len(y_disc_gen)):
|
||||
stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
|
||||
loss += stft_loss
|
||||
return loss / len(y_disc_gen)
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_gen):
|
||||
loss = 0.0
|
||||
for i in range(len(fmap_r)):
|
||||
for j in range(len(fmap_r[i])):
|
||||
stft_loss = (
|
||||
(fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
|
||||
).mean()
|
||||
loss += stft_loss
|
||||
return loss / (len(fmap_r) * len(fmap_r[0]))
|
||||
|
||||
|
||||
def sim_loss(y_disc_r, y_disc_gen):
|
||||
loss = 0.0
|
||||
for i in range(len(y_disc_r)):
|
||||
loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
|
||||
return loss / len(y_disc_r)
|
||||
|
||||
|
||||
# def sisnr_loss(x, s, eps=1e-8):
|
||||
# """
|
||||
# calculate training loss
|
||||
# input:
|
||||
# x: separated signal, N x S tensor, estimate value
|
||||
# s: reference signal, N x S tensor, True value
|
||||
# Return:
|
||||
# sisnr: N tensor
|
||||
# """
|
||||
# if x.shape != s.shape:
|
||||
# if x.shape[-1] > s.shape[-1]:
|
||||
# x = x[:, :s.shape[-1]]
|
||||
# else:
|
||||
# s = s[:, :x.shape[-1]]
|
||||
# def l2norm(mat, keepdim=False):
|
||||
# return torch.norm(mat, dim=-1, keepdim=keepdim)
|
||||
# if x.shape != s.shape:
|
||||
# raise RuntimeError(
|
||||
# "Dimention mismatch when calculate si-snr, {} vs {}".format(
|
||||
# x.shape, s.shape))
|
||||
# x_zm = x - torch.mean(x, dim=-1, keepdim=True)
|
||||
# s_zm = s - torch.mean(s, dim=-1, keepdim=True)
|
||||
# t = torch.sum(
|
||||
# x_zm * s_zm, dim=-1,
|
||||
# keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
|
||||
# loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
|
||||
# return torch.sum(loss) / x.shape[0]
|
||||
|
||||
|
||||
def reconstruction_loss(x, G_x, args, eps=1e-7):
|
||||
# NOTE (lsx): hard-coded now
|
||||
L = args.lambda_wav * F.mse_loss(x, G_x)  # wav reconstruction loss (MSE)
|
||||
# loss_sisnr = sisnr_loss(G_x, x) #
|
||||
# L += 0.01*loss_sisnr
|
||||
# 2^6=64 -> 2^10=1024
|
||||
# NOTE (lsx): add 2^11
|
||||
for i in range(6, 12):
|
||||
# for i in range(5, 12): # Encodec setting
|
||||
s = 2**i
|
||||
melspec = MelSpectrogram(
|
||||
sample_rate=args.sr,
|
||||
n_fft=max(s, 512),
|
||||
win_length=s,
|
||||
hop_length=s // 4,
|
||||
n_mels=64,
|
||||
wkwargs={"device": args.device},
|
||||
).to(args.device)
|
||||
S_x = melspec(x)
|
||||
S_G_x = melspec(G_x)
|
||||
l1_loss = (S_x - S_G_x).abs().mean()
|
||||
l2_loss = (
|
||||
((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
|
||||
dim=-2
|
||||
)
|
||||
** 0.5
|
||||
).mean()
|
||||
|
||||
alpha = (s / 2) ** 0.5
|
||||
L += l1_loss + alpha * l2_loss
|
||||
return L
|
||||
|
||||
|
||||
def criterion_d(
|
||||
y_disc_r,
|
||||
y_disc_gen,
|
||||
fmap_r_det,
|
||||
fmap_gen_det,
|
||||
y_df_hat_r,
|
||||
y_df_hat_g,
|
||||
fmap_f_r,
|
||||
fmap_f_g,
|
||||
y_ds_hat_r,
|
||||
y_ds_hat_g,
|
||||
fmap_s_r,
|
||||
fmap_s_g,
|
||||
):
|
||||
"""Hinge Loss"""
|
||||
loss = 0.0
|
||||
loss1 = 0.0
|
||||
loss2 = 0.0
|
||||
loss3 = 0.0
|
||||
for i in range(len(y_disc_r)):
|
||||
loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
|
||||
for i in range(len(y_df_hat_r)):
|
||||
loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
|
||||
for i in range(len(y_ds_hat_r)):
|
||||
loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()
|
||||
|
||||
loss = (
|
||||
loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
|
||||
) / 3.0
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def criterion_g(
|
||||
commit_loss,
|
||||
x,
|
||||
G_x,
|
||||
fmap_r,
|
||||
fmap_gen,
|
||||
y_disc_r,
|
||||
y_disc_gen,
|
||||
y_df_hat_r,
|
||||
y_df_hat_g,
|
||||
fmap_f_r,
|
||||
fmap_f_g,
|
||||
y_ds_hat_r,
|
||||
y_ds_hat_g,
|
||||
fmap_s_r,
|
||||
fmap_s_g,
|
||||
args,
|
||||
):
|
||||
adv_g_loss = adversarial_g_loss(y_disc_gen)
|
||||
feat_loss = (
|
||||
feature_loss(fmap_r, fmap_gen)
|
||||
+ sim_loss(y_disc_r, y_disc_gen)
|
||||
+ feature_loss(fmap_f_r, fmap_f_g)
|
||||
+ sim_loss(y_df_hat_r, y_df_hat_g)
|
||||
+ feature_loss(fmap_s_r, fmap_s_g)
|
||||
+ sim_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
) / 3.0
|
||||
rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
|
||||
total_loss = (
|
||||
args.lambda_com * commit_loss
|
||||
+ args.lambda_adv * adv_g_loss
|
||||
+ args.lambda_feat * feat_loss
|
||||
+ args.lambda_rec * rec_loss
|
||||
)
|
||||
return total_loss, adv_g_loss, feat_loss, rec_loss
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
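

# e.g. adopt_weight(args.lambda_adv, global_step=100, threshold=500) returns 0.0,
# i.e. the adversarial terms stay disabled until `discriminator_iter_start`.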
|
||||
|
||||
|
||||
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
|
||||
# At steps where global_step % 3 == 0 (0, 3, 6, 9, ...), the discriminator is not updated.
|
||||
if global_step % 3 == 0:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
print("last_layer cannot be none")
|
||||
assert 1 == 2
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
|
||||
d_weight = d_weight * args.lambda_adv
|
||||
return d_weight
|
||||
|
||||
|
||||
def loss_g(
|
||||
codebook_loss,
|
||||
speech,
|
||||
speech_hat,
|
||||
fmap,
|
||||
fmap_hat,
|
||||
y,
|
||||
y_hat,
|
||||
global_step,
|
||||
y_df,
|
||||
y_df_hat,
|
||||
y_ds,
|
||||
y_ds_hat,
|
||||
fmap_f,
|
||||
fmap_f_hat,
|
||||
fmap_s,
|
||||
fmap_s_hat,
|
||||
args=None,
|
||||
):
|
||||
"""
|
||||
args:
|
||||
codebook_loss: commit loss.
|
||||
speech: ground-truth wav.
|
||||
speech_hat: reconstructed wav.
|
||||
fmap: real stft-D feature map.
|
||||
fmap_hat: fake stft-D feature map.
|
||||
y: real stft-D logits.
|
||||
y_hat: fake stft-D logits.
|
||||
global_step: global training step.
|
||||
y_df: real MPD logits.
|
||||
y_df_hat: fake MPD logits.
|
||||
y_ds: real MSD logits.
|
||||
y_ds_hat: fake MSD logits.
|
||||
fmap_f: real MPD feature map.
|
||||
fmap_f_hat: fake MPD feature map.
|
||||
fmap_s: real MSD feature map.
|
||||
fmap_s_hat: fake MSD feature map.
|
||||
"""
|
||||
rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
|
||||
adv_g_loss = adversarial_g_loss(y_hat)
|
||||
adv_mpd_loss = adversarial_g_loss(y_df_hat)
|
||||
adv_msd_loss = adversarial_g_loss(y_ds_hat)
|
||||
adv_loss = (
|
||||
adv_g_loss + adv_mpd_loss + adv_msd_loss
|
||||
) / 3.0 # NOTE(lsx): need to divide by 3?
|
||||
feat_loss = feature_loss(
|
||||
fmap, fmap_hat
|
||||
) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
|
||||
feat_loss_mpd = feature_loss(
|
||||
fmap_f, fmap_f_hat
|
||||
) # + sim_loss(y_df_hat_r, y_df_hat_g)
|
||||
feat_loss_msd = feature_loss(
|
||||
fmap_s, fmap_s_hat
|
||||
) # + sim_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
|
||||
d_weight = torch.tensor(1.0)
|
||||
|
||||
disc_factor = adopt_weight(
|
||||
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
||||
)
|
||||
if disc_factor == 0.0:
|
||||
fm_loss_wt = 0
|
||||
else:
|
||||
fm_loss_wt = args.lambda_feat
|
||||
|
||||
loss = (
|
||||
rec_loss
|
||||
+ d_weight * disc_factor * adv_loss
|
||||
+ fm_loss_wt * feat_loss_tot
|
||||
+ args.lambda_com * codebook_loss
|
||||
)
|
||||
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
|
||||
|
||||
|
||||
def loss_dis(
|
||||
y,
|
||||
y_hat,
|
||||
fmap,
|
||||
fmap_hat,
|
||||
y_df,
|
||||
y_df_hat,
|
||||
fmap_f,
|
||||
fmap_f_hat,
|
||||
y_ds,
|
||||
y_ds_hat,
|
||||
fmap_s,
|
||||
fmap_s_hat,
|
||||
global_step,
|
||||
args,
|
||||
):
|
||||
disc_factor = adopt_weight(
|
||||
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
||||
)
|
||||
d_loss = disc_factor * criterion_d(
|
||||
y,
|
||||
y_hat,
|
||||
fmap,
|
||||
fmap_hat,
|
||||
y_df,
|
||||
y_df_hat,
|
||||
fmap_f,
|
||||
fmap_f_hat,
|
||||
y_ds,
|
||||
y_ds_hat,
|
||||
fmap_s,
|
||||
fmap_s_hat,
|
||||
)
|
||||
return d_loss
|
229
egs/libritts/CODEC/encodec/models/discriminators.py
Normal file
@ -0,0 +1,229 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from utils import get_2d_padding, get_padding
|
||||
|
||||
from ..modules import NormConv1d, NormConv2d
|
||||
|
||||
|
||||
class DiscriminatorP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
period,
|
||||
kernel_size=5,
|
||||
stride=3,
|
||||
activation: str = "LeakyReLU",
|
||||
activation_params: dict = {"negative_slope": 0.2},
|
||||
):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
|
||||
self.period = period
|
||||
self.activation = getattr(torch.nn, activation)(**activation_params)
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
NormConv2d(
|
||||
1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)
|
||||
),
|
||||
NormConv2d(
|
||||
32,
|
||||
32,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
),
|
||||
NormConv2d(
|
||||
32,
|
||||
32,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
),
|
||||
NormConv2d(
|
||||
32,
|
||||
32,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
),
|
||||
NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)),
|
||||
]
|
||||
)
|
||||
self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = self.activation(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class DiscriminatorS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation: str = "LeakyReLU",
|
||||
activation_params: dict = {"negative_slope": 0.2},
|
||||
):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
self.activation = getattr(torch.nn, activation)(**activation_params)
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
NormConv1d(1, 32, 15, 1, padding=7),
|
||||
NormConv1d(32, 32, 41, 2, groups=4, padding=20),
|
||||
NormConv1d(32, 32, 41, 2, groups=16, padding=20),
|
||||
NormConv1d(32, 32, 41, 4, groups=16, padding=20),
|
||||
NormConv1d(32, 32, 41, 4, groups=16, padding=20),
|
||||
NormConv1d(32, 32, 41, 1, groups=16, padding=20),
|
||||
NormConv1d(32, 32, 5, 1, padding=2),
|
||||
]
|
||||
)
|
||||
self.conv_post = NormConv1d(32, 1, 3, 1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = self.activation(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
return x, fmap
|
||||
|
||||
|
||||
class DiscriminatorSTFT(nn.Module):
|
||||
"""STFT sub-discriminator.
|
||||
Args:
|
||||
n_filters (int): Number of filters in convolutions
|
||||
in_channels (int): Number of input channels. Default: 1
|
||||
out_channels (int): Number of output channels. Default: 1
|
||||
n_fft (int): Size of FFT for each scale. Default: 1024
|
||||
hop_length (int): Length of hop between STFT windows for each scale. Default: 256
|
||||
kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
|
||||
stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
|
||||
dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
|
||||
win_length (int): Window size for each scale. Default: 1024
|
||||
normalized (bool): Whether to normalize by magnitude after stft. Default: True
|
||||
norm (str): Normalization method. Default: `'weight_norm'`
|
||||
activation (str): Activation function. Default: `'LeakyReLU'`
|
||||
activation_params (dict): Parameters to provide to the activation function.
|
||||
growth (int): Growth factor for the filters. Default: 1
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_filters: int,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 256,
|
||||
win_length: int = 1024,
|
||||
max_filters: int = 1024,
|
||||
filters_scale: int = 1,
|
||||
kernel_size: Tuple[int, int] = (3, 9),
|
||||
dilations: List[int] = [1, 2, 4],
|
||||
stride: Tuple[int, int] = (1, 2),
|
||||
normalized: bool = True,
|
||||
norm: str = "weight_norm",
|
||||
activation: str = "LeakyReLU",
|
||||
activation_params: dict = {"negative_slope": 0.2},
|
||||
):
|
||||
super().__init__()
|
||||
assert len(kernel_size) == 2
|
||||
assert len(stride) == 2
|
||||
self.filters = n_filters
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.normalized = normalized
|
||||
self.activation = getattr(torch.nn, activation)(**activation_params)
|
||||
self.spec_transform = torchaudio.transforms.Spectrogram(
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window_fn=torch.hann_window,
|
||||
normalized=self.normalized,
|
||||
center=False,
|
||||
pad_mode=None,
|
||||
power=None,
|
||||
)
|
||||
spec_channels = 2 * self.in_channels
|
||||
self.convs = nn.ModuleList()
|
||||
self.convs.append(
|
||||
NormConv2d(
|
||||
spec_channels,
|
||||
self.filters,
|
||||
kernel_size=kernel_size,
|
||||
padding=get_2d_padding(kernel_size),
|
||||
)
|
||||
)
|
||||
in_chs = min(filters_scale * self.filters, max_filters)
|
||||
for i, dilation in enumerate(dilations):
|
||||
out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
|
||||
self.convs.append(
|
||||
NormConv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(dilation, 1),
|
||||
padding=get_2d_padding(kernel_size, (dilation, 1)),
|
||||
norm=norm,
|
||||
)
|
||||
)
|
||||
in_chs = out_chs
|
||||
out_chs = min(
|
||||
(filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
|
||||
)
|
||||
self.convs.append(
|
||||
NormConv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=(kernel_size[0], kernel_size[0]),
|
||||
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
||||
norm=norm,
|
||||
)
|
||||
)
|
||||
self.conv_post = NormConv2d(
|
||||
out_chs,
|
||||
self.out_channels,
|
||||
kernel_size=(kernel_size[0], kernel_size[0]),
|
||||
padding=get_2d_padding((kernel_size[0], kernel_size[0])),
|
||||
norm=norm,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
fmap = []
|
||||
# print('x ', x.shape)
|
||||
z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
|
||||
# print('z ', z.shape)
|
||||
z = torch.cat([z.real, z.imag], dim=1)
|
||||
# print('cat_z ', z.shape)
|
||||
z = rearrange(z, "b c w t -> b c t w")
|
||||
for i, layer in enumerate(self.convs):
|
||||
z = layer(z)
|
||||
z = self.activation(z)
|
||||
# print('z i', i, z.shape)
|
||||
fmap.append(z)
|
||||
z = self.conv_post(z)
|
||||
# print('logit ', z.shape)
|
||||
return z, fmap
|
12
egs/libritts/CODEC/encodec/models/utils.py
Normal file
@ -0,0 +1,12 @@
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)):
|
||||
return (
|
||||
((kernel_size[0] - 1) * dilation[0]) // 2,
|
||||
((kernel_size[1] - 1) * dilation[1]) // 2,
|
||||
)
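# Minimal illustrative sketch (hypothetical example, arbitrary sizes): a quick numeric
# check of the two helpers above. With kernel_size=(3, 9) and dilation=(2, 1),
# "same"-style padding is ((3 - 1) * 2) // 2 = 2 and ((9 - 1) * 1) // 2 = 4.
def _example_padding():
    assert get_padding(kernel_size=7, dilation=1) == 3
    assert get_2d_padding((3, 9)) == (1, 4)
    assert get_2d_padding((3, 9), (2, 1)) == (2, 4)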
|
20
egs/libritts/CODEC/encodec/modules/__init__.py
Normal file
20
egs/libritts/CODEC/encodec/modules/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Torch modules."""
|
||||
# flake8: noqa
|
||||
from .conv import (
|
||||
NormConv1d,
|
||||
NormConv2d,
|
||||
NormConvTranspose1d,
|
||||
NormConvTranspose2d,
|
||||
SConv1d,
|
||||
SConvTranspose1d,
|
||||
pad1d,
|
||||
unpad1d,
|
||||
)
|
||||
from .lstm import SLSTM
|
||||
from .seanet import SEANetDecoder, SEANetEncoder
|
||||
from .transformer import StreamingTransformerEncoder
|
334
egs/libritts/CODEC/encodec/modules/conv.py
Normal file
334
egs/libritts/CODEC/encodec/modules/conv.py
Normal file
@ -0,0 +1,334 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Convolutional layers wrappers and utilities."""
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import spectral_norm, weight_norm
|
||||
|
||||
from .norm import ConvLayerNorm
|
||||
|
||||
CONV_NORMALIZATIONS = frozenset(
|
||||
[
|
||||
"none",
|
||||
"weight_norm",
|
||||
"spectral_norm",
|
||||
"time_layer_norm",
|
||||
"layer_norm",
|
||||
"time_group_norm",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
|
||||
assert norm in CONV_NORMALIZATIONS
|
||||
if norm == "weight_norm":
|
||||
return weight_norm(module)
|
||||
elif norm == "spectral_norm":
|
||||
return spectral_norm(module)
|
||||
else:
|
||||
# We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
|
||||
# doesn't need reparametrization.
|
||||
return module
|
||||
|
||||
|
||||
def get_norm_module(
|
||||
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
|
||||
) -> nn.Module:
|
||||
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
||||
module is causal, or raise an error if the normalization doesn't support causal evaluation.
|
||||
"""
|
||||
assert norm in CONV_NORMALIZATIONS
|
||||
if norm == "layer_norm":
|
||||
assert isinstance(module, nn.modules.conv._ConvNd)
|
||||
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
||||
elif norm == "time_group_norm":
|
||||
if causal:
|
||||
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
||||
assert isinstance(module, nn.modules.conv._ConvNd)
|
||||
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
||||
else:
|
||||
return nn.Identity()
|
||||
|
||||
|
||||
def get_extra_padding_for_conv1d(
|
||||
x: Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
||||
) -> int:
|
||||
"""See `pad_for_conv1d`."""
|
||||
length = x.shape[-1]
|
||||
n_frames = (length - kernel_size + padding_total) / stride + 1
|
||||
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
||||
return ideal_length - length
|
||||
|
||||
|
||||
def pad_for_conv1d(x: Tensor, kernel_size: int, stride: int, padding_total: int = 0):
|
||||
"""Pad for a convolution to make sure that the last window is full.
|
||||
Extra padding is added at the end. This is required to ensure that we can rebuild
|
||||
an output of the same length, as otherwise, even with padding, some time steps
|
||||
might get removed.
|
||||
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
||||
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
||||
1 2 3 # (output frames of a convolution, last 0 is never used)
|
||||
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
||||
1 2 3 4 # once you remove padding, we are missing one time step!
|
||||
"""
|
||||
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
||||
return F.pad(x, (0, extra_padding))
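# Minimal illustrative sketch (hypothetical values): with kernel_size=4, stride=2 and
# padding_total=2, an input of length 10 already yields whole frames (no extra padding),
# while an input of length 9 needs one extra sample so that the last window is full.
def _example_pad_for_conv1d():
    import torch

    x = torch.randn(1, 1, 9)
    assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=2) == 1
    assert pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=2).shape[-1] == 10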
|
||||
|
||||
|
||||
def pad1d(
|
||||
x: Tensor,
|
||||
paddings: Tuple[int, int],
|
||||
mode: str = "zero",
|
||||
value: float = 0.0,
|
||||
):
|
||||
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
||||
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
||||
"""
|
||||
length = x.shape[-1]
|
||||
padding_left, padding_right = paddings
|
||||
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
||||
if mode == "reflect":
|
||||
max_pad = max(padding_left, padding_right)
|
||||
extra_pad = 0
|
||||
if length <= max_pad:
|
||||
extra_pad = max_pad - length + 1
|
||||
x = F.pad(x, (0, extra_pad))
|
||||
padded = F.pad(x, paddings, mode, value)
|
||||
end = padded.shape[-1] - extra_pad
|
||||
return padded[..., :end]
|
||||
else:
|
||||
return F.pad(x, paddings, mode, value)
|
||||
|
||||
|
||||
def unpad1d(x: Tensor, paddings: Tuple[int, int]):
|
||||
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
||||
padding_left, padding_right = paddings
|
||||
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
||||
assert (padding_left + padding_right) <= x.shape[-1]
|
||||
end = x.shape[-1] - padding_right
|
||||
return x[..., padding_left:end]
|
||||
|
||||
|
||||
class NormConv1d(nn.Module):
|
||||
"""Wrapper around Conv1d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
norm_kwargs: Dict[str, Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
||||
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
||||
self.norm_type = norm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class NormConv2d(nn.Module):
|
||||
"""Wrapper around Conv2d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
norm: str = "none",
|
||||
norm_kwargs: Dict[str, Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
||||
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
||||
self.norm_type = norm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class NormConvTranspose1d(nn.Module):
|
||||
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
norm_kwargs: Dict[str, Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.convtr = apply_parametrization_norm(
|
||||
nn.ConvTranspose1d(*args, **kwargs), norm
|
||||
)
|
||||
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
||||
self.norm_type = norm
|
||||
|
||||
def forward(self, x):
|
||||
x = self.convtr(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class NormConvTranspose2d(nn.Module):
|
||||
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
||||
to provide a uniform interface across normalization approaches.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
norm: str = "none",
|
||||
norm_kwargs: Dict[str, Any] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.convtr = apply_parametrization_norm(
|
||||
nn.ConvTranspose2d(*args, **kwargs), norm
|
||||
)
|
||||
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.convtr(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class SConv1d(nn.Module):
|
||||
"""Conv1d with some builtin handling of asymmetric or causal padding
|
||||
and normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
norm_kwargs: Dict[str, Any] = {},
|
||||
pad_mode: str = "reflect",
|
||||
):
|
||||
super().__init__()
|
||||
# warn user on unusual setup between dilation and stride
|
||||
if stride > 1 and dilation > 1:
|
||||
logging.warning(
|
||||
"SConv1d has been initialized with stride > 1 and dilation > 1"
|
||||
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
|
||||
)
|
||||
self.conv = NormConv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
causal=causal,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_kwargs,
|
||||
)
|
||||
self.causal = causal
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
kernel_size = self.conv.conv.kernel_size[0]
|
||||
stride = self.conv.conv.stride[0]
|
||||
dilation = self.conv.conv.dilation[0]
|
||||
padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
||||
extra_padding = get_extra_padding_for_conv1d(
|
||||
x, kernel_size, stride, padding_total
|
||||
)
|
||||
if self.causal:
|
||||
# Left padding for causal
|
||||
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
||||
else:
|
||||
# Asymmetric padding required for odd strides
|
||||
padding_right = padding_total // 2
|
||||
padding_left = padding_total - padding_right
|
||||
x = pad1d(
|
||||
x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
|
||||
)
|
||||
return self.conv(x)
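# Minimal illustrative sketch (hypothetical sizes): with the padding scheme above,
# SConv1d keeps the output length at ceil(T / stride) for both the causal and the
# non-causal variants, e.g. 24000 samples with stride 4 give 6000 frames.
def _example_sconv1d_length():
    import torch

    x = torch.randn(1, 1, 24000)
    for causal in (False, True):
        conv = SConv1d(1, 8, kernel_size=7, stride=4, causal=causal)
        assert conv(x).shape[-1] == 6000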
|
||||
|
||||
|
||||
class SConvTranspose1d(nn.Module):
|
||||
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
||||
and normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
causal: bool = False,
|
||||
norm: str = "none",
|
||||
trim_right_ratio: float = 1.0,
|
||||
norm_kwargs: Dict[str, Any] = {},
|
||||
):
|
||||
super().__init__()
|
||||
self.convtr = NormConvTranspose1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
causal=causal,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_kwargs,
|
||||
)
|
||||
self.causal = causal
|
||||
self.trim_right_ratio = trim_right_ratio
|
||||
assert (
|
||||
self.causal or self.trim_right_ratio == 1.0
|
||||
), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
||||
assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
|
||||
|
||||
def forward(self, x):
|
||||
kernel_size = self.convtr.convtr.kernel_size[0]
|
||||
stride = self.convtr.convtr.stride[0]
|
||||
padding_total = kernel_size - stride
|
||||
|
||||
y = self.convtr(x)
|
||||
|
||||
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
||||
# removed at the very end, when keeping only the right length for the output,
|
||||
# as removing it here would require also passing the length at the matching layer
|
||||
# in the encoder.
|
||||
if self.causal:
|
||||
# Trim the padding on the right according to the specified ratio
|
||||
# if trim_right_ratio = 1.0, trim everything from right
|
||||
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
||||
padding_left = padding_total - padding_right
|
||||
y = unpad1d(y, (padding_left, padding_right))
|
||||
else:
|
||||
# Asymmetric padding required for odd strides
|
||||
padding_right = padding_total // 2
|
||||
padding_left = padding_total - padding_right
|
||||
y = unpad1d(y, (padding_left, padding_right))
|
||||
return y
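# Minimal illustrative sketch (hypothetical sizes): after trimming the fixed padding,
# SConvTranspose1d upsamples by exactly `stride`, mirroring the downsampling of SConv1d.
def _example_sconvtranspose1d_length():
    import torch

    x = torch.randn(1, 8, 75)
    convtr = SConvTranspose1d(8, 8, kernel_size=16, stride=8)
    assert convtr(x).shape[-1] == 75 * 8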
|
27
egs/libritts/CODEC/encodec/modules/lstm.py
Normal file
27
egs/libritts/CODEC/encodec/modules/lstm.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""LSTM layers module."""
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SLSTM(nn.Module):
|
||||
"""
|
||||
LSTM without worrying about the hidden state, nor the layout of the data.
|
||||
Expects input as convolutional layout.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
||||
super().__init__()
|
||||
self.skip = skip
|
||||
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(2, 0, 1)
|
||||
y, _ = self.lstm(x)
|
||||
if self.skip:
|
||||
y = y + x
|
||||
y = y.permute(1, 2, 0)
|
||||
return y
|
28
egs/libritts/CODEC/encodec/modules/norm.py
Normal file
28
egs/libritts/CODEC/encodec/modules/norm.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Normalization modules."""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvLayerNorm(nn.LayerNorm):
|
||||
"""
|
||||
Convolution-friendly LayerNorm that moves channels to last dimensions
|
||||
before running the normalization and moves them back to original position right after.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs):
|
||||
super().__init__(normalized_shape, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = einops.rearrange(x, "b ... t -> b t ...")
|
||||
x = super().forward(x)
|
||||
x = einops.rearrange(x, "b t ... -> b ... t")
|
||||
return x
|
368
egs/libritts/CODEC/encodec/modules/seanet.py
Normal file
368
egs/libritts/CODEC/encodec/modules/seanet.py
Normal file
@ -0,0 +1,368 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Encodec SEANet-based encoder and decoder implementation."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from modules import SLSTM, SConv1d, SConvTranspose1d
|
||||
|
||||
|
||||
class SEANetResnetBlock(nn.Module):
|
||||
"""Residual block from SEANet model.
|
||||
Args:
|
||||
dim (int): Dimension of the input/output
|
||||
kernel_sizes (list): List of kernel sizes for the convolutions.
|
||||
dilations (list): List of dilations for the convolutions.
|
||||
activation (str): Activation function.
|
||||
activation_params (dict): Parameters to provide to the activation function
|
||||
norm (str): Normalization method.
|
||||
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
||||
causal (bool): Whether to use fully causal convolution.
|
||||
pad_mode (str): Padding mode for the convolutions.
|
||||
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
|
||||
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_sizes: List[int] = [3, 1],
|
||||
dilations: List[int] = [1, 1],
|
||||
activation: str = "ELU",
|
||||
activation_params: Dict = {"alpha": 1.0},
|
||||
norm: str = "weight_norm",
|
||||
norm_params: Dict[str, Any] = {},
|
||||
causal: bool = False,
|
||||
pad_mode: str = "reflect",
|
||||
compress: int = 2,
|
||||
true_skip: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
assert len(kernel_sizes) == len(
|
||||
dilations
|
||||
), "Number of kernel sizes should match number of dilations"
|
||||
act = getattr(nn, activation)
|
||||
hidden = dim // compress
|
||||
block = []
|
||||
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
||||
in_chs = dim if i == 0 else hidden
|
||||
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
||||
block += [
|
||||
act(**activation_params),
|
||||
SConv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
self.block = nn.Sequential(*block)
|
||||
self.shortcut: nn.Module
|
||||
if true_skip:
|
||||
self.shortcut = nn.Identity()
|
||||
else:
|
||||
self.shortcut = SConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=1,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.shortcut(x) + self.block(x)
|
||||
|
||||
|
||||
class SEANetEncoder(nn.Module):
|
||||
"""SEANet encoder.
|
||||
Args:
|
||||
channels (int): Audio channels.
|
||||
dimension (int): Intermediate representation dimension.
|
||||
n_filters (int): Base width for the model.
|
||||
n_residual_layers (int): nb of residual layers.
|
||||
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
||||
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
||||
which must match the decoder order.
|
||||
activation (str): Activation function.
|
||||
activation_params (dict): Parameters to provide to the activation function
|
||||
norm (str): Normalization method.
|
||||
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
||||
kernel_size (int): Kernel size for the initial convolution.
|
||||
last_kernel_size (int): Kernel size for the last convolution.
|
||||
residual_kernel_size (int): Kernel size for the residual layers.
|
||||
dilation_base (int): How much to increase the dilation with each layer.
|
||||
causal (bool): Whether to use fully causal convolution.
|
||||
pad_mode (str): Padding mode for the convolutions.
|
||||
true_skip (bool): Whether to use true skip connection or a simple
|
||||
(streamable) convolution as the skip connection in the residual network blocks.
|
||||
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
||||
lstm (int): Number of LSTM layers at the end of the encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 1,
|
||||
dimension: int = 128,
|
||||
n_filters: int = 32,
|
||||
n_residual_layers: int = 1,
|
||||
ratios: List[int] = [8, 5, 4, 2],
|
||||
activation: str = "ELU",
|
||||
activation_params: dict = {"alpha": 1.0},
|
||||
norm: str = "weight_norm",
|
||||
norm_params: Dict[str, Any] = {},
|
||||
kernel_size: int = 7,
|
||||
last_kernel_size: int = 7,
|
||||
residual_kernel_size: int = 3,
|
||||
dilation_base: int = 2,
|
||||
causal: bool = False,
|
||||
pad_mode: str = "reflect",
|
||||
true_skip: bool = False,
|
||||
compress: int = 2,
|
||||
lstm: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.dimension = dimension
|
||||
self.n_filters = n_filters
|
||||
self.ratios = list(reversed(ratios))
|
||||
del ratios
|
||||
self.n_residual_layers = n_residual_layers
|
||||
self.hop_length = np.prod(self.ratios)  # product of all stride ratios
|
||||
|
||||
act = getattr(nn, activation)
|
||||
mult = 1
|
||||
model: List[nn.Module] = [
|
||||
SConv1d(
|
||||
channels,
|
||||
mult * n_filters,
|
||||
kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
]
|
||||
# Downsample to raw audio scale
|
||||
for i, ratio in enumerate(self.ratios):
|
||||
# Add residual layers
|
||||
for j in range(n_residual_layers):
|
||||
model += [
|
||||
SEANetResnetBlock(
|
||||
mult * n_filters,
|
||||
kernel_sizes=[residual_kernel_size, 1],
|
||||
dilations=[dilation_base**j, 1],
|
||||
norm=norm,
|
||||
norm_params=norm_params,
|
||||
activation=activation,
|
||||
activation_params=activation_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
compress=compress,
|
||||
true_skip=true_skip,
|
||||
)
|
||||
]
|
||||
|
||||
# Add downsampling layers
|
||||
model += [
|
||||
act(**activation_params),
|
||||
SConv1d(
|
||||
mult * n_filters,
|
||||
mult * n_filters * 2,
|
||||
kernel_size=ratio * 2,
|
||||
stride=ratio,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
mult *= 2
|
||||
|
||||
if lstm:
|
||||
model += [SLSTM(mult * n_filters, num_layers=lstm)]
|
||||
|
||||
model += [
|
||||
act(**activation_params),
|
||||
SConv1d(
|
||||
mult * n_filters,
|
||||
dimension,
|
||||
last_kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class SEANetDecoder(nn.Module):
|
||||
"""SEANet decoder.
|
||||
Args:
|
||||
channels (int): Audio channels.
|
||||
dimension (int): Intermediate representation dimension.
|
||||
n_filters (int): Base width for the model.
|
||||
n_residual_layers (int): nb of residual layers.
|
||||
ratios (Sequence[int]): kernel size and stride ratios
|
||||
activation (str): Activation function.
|
||||
activation_params (dict): Parameters to provide to the activation function
|
||||
final_activation (str): Final activation function after all convolutions.
|
||||
final_activation_params (dict): Parameters to provide to the activation function
|
||||
norm (str): Normalization method.
|
||||
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
||||
kernel_size (int): Kernel size for the initial convolution.
|
||||
last_kernel_size (int): Kernel size for the last convolution.
|
||||
residual_kernel_size (int): Kernel size for the residual layers.
|
||||
dilation_base (int): How much to increase the dilation with each layer.
|
||||
causal (bool): Whether to use fully causal convolution.
|
||||
pad_mode (str): Padding mode for the convolutions.
|
||||
true_skip (bool): Whether to use true skip connection or a simple
|
||||
(streamable) convolution as the skip connection in the residual network blocks.
|
||||
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
||||
lstm (int): Number of LSTM layers at the beginning of the decoder.
|
||||
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
|
||||
If equal to 1.0, it means that all the trimming is done at the right.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 1,
|
||||
dimension: int = 128,
|
||||
n_filters: int = 32,
|
||||
n_residual_layers: int = 1,
|
||||
ratios: List[int] = [8, 5, 4, 2],
|
||||
activation: str = "ELU",
|
||||
activation_params: dict = {"alpha": 1.0},
|
||||
final_activation: Optional[str] = None,
|
||||
final_activation_params: Optional[dict] = None,
|
||||
norm: str = "weight_norm",
|
||||
norm_params: Dict[str, Any] = {},
|
||||
kernel_size: int = 7,
|
||||
last_kernel_size: int = 7,
|
||||
residual_kernel_size: int = 3,
|
||||
dilation_base: int = 2,
|
||||
causal: bool = False,
|
||||
pad_mode: str = "reflect",
|
||||
true_skip: bool = False,
|
||||
compress: int = 2,
|
||||
lstm: int = 2,
|
||||
trim_right_ratio: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dimension = dimension
|
||||
self.channels = channels
|
||||
self.n_filters = n_filters
|
||||
self.ratios = ratios
|
||||
del ratios
|
||||
self.n_residual_layers = n_residual_layers
|
||||
self.hop_length = np.prod(self.ratios)
|
||||
|
||||
act = getattr(nn, activation)
|
||||
mult = int(2 ** len(self.ratios))
|
||||
model: List[nn.Module] = [
|
||||
SConv1d(
|
||||
dimension,
|
||||
mult * n_filters,
|
||||
kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
]
|
||||
|
||||
if lstm:
|
||||
model += [SLSTM(mult * n_filters, num_layers=lstm)]
|
||||
|
||||
# Upsample to raw audio scale
|
||||
for i, ratio in enumerate(self.ratios):
|
||||
# Add upsampling layers
|
||||
model += [
|
||||
act(**activation_params),
|
||||
SConvTranspose1d(
|
||||
mult * n_filters,
|
||||
mult * n_filters // 2,
|
||||
kernel_size=ratio * 2,
|
||||
stride=ratio,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
trim_right_ratio=trim_right_ratio,
|
||||
),
|
||||
]
|
||||
# Add residual layers
|
||||
for j in range(n_residual_layers):
|
||||
model += [
|
||||
SEANetResnetBlock(
|
||||
mult * n_filters // 2,
|
||||
kernel_sizes=[residual_kernel_size, 1],
|
||||
dilations=[dilation_base**j, 1],
|
||||
activation=activation,
|
||||
activation_params=activation_params,
|
||||
norm=norm,
|
||||
norm_params=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
compress=compress,
|
||||
true_skip=true_skip,
|
||||
)
|
||||
]
|
||||
|
||||
mult //= 2
|
||||
|
||||
# Add final layers
|
||||
model += [
|
||||
act(**activation_params),
|
||||
SConv1d(
|
||||
n_filters,
|
||||
channels,
|
||||
last_kernel_size,
|
||||
norm=norm,
|
||||
norm_kwargs=norm_params,
|
||||
causal=causal,
|
||||
pad_mode=pad_mode,
|
||||
),
|
||||
]
|
||||
# Add optional final activation to decoder (e.g. tanh)
|
||||
if final_activation is not None:
|
||||
final_act = getattr(nn, final_activation)
|
||||
final_activation_params = final_activation_params or {}
|
||||
model += [final_act(**final_activation_params)]
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, z):
|
||||
y = self.model(z)
|
||||
return y
|
||||
|
||||
|
||||
def test():
|
||||
import torch
|
||||
|
||||
encoder = SEANetEncoder()
|
||||
decoder = SEANetDecoder()
|
||||
x = torch.randn(1, 1, 24000)
|
||||
z = encoder(x)
|
||||
print("z ", z.shape)
|
||||
|
||||
assert list(z.shape) == [1, 128, 75], z.shape
|
||||
y = decoder(z)
|
||||
assert y.shape == x.shape, (x.shape, y.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
141
egs/libritts/CODEC/encodec/modules/transformer.py
Normal file
141
egs/libritts/CODEC/encodec/modules/transformer.py
Normal file
@ -0,0 +1,141 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""A streamable transformer."""
|
||||
import typing as tp
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000):
|
||||
"""Create time embedding for the given positions, target dimension `dim`."""
|
||||
# We aim for BTC format
|
||||
assert dim % 2 == 0
|
||||
half_dim = dim // 2
|
||||
adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
|
||||
phase = positions / (max_period ** (adim / (half_dim - 1)))
|
||||
return torch.cat(
|
||||
[
|
||||
torch.cos(phase),
|
||||
torch.sin(phase),
|
||||
],
|
||||
dim=-1,
|
||||
)
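# Minimal illustrative sketch (hypothetical sizes): the embedding follows the (B, T, C)
# layout used by the encoder, so positions of shape (1, T, 1) and an even dimension C
# produce an embedding of shape (1, T, C).
def _example_sin_embedding():
    positions = torch.arange(50).view(1, -1, 1)
    emb = create_sin_embedding(positions, dim=128)
    assert emb.shape == (1, 50, 128)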
|
||||
|
||||
|
||||
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
|
||||
def forward(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore
|
||||
if self.norm_first:
|
||||
sa_input = self.norm1(x)
|
||||
x = x + self._sa_block(sa_input, x_past, past_context)
|
||||
x = x + self._ff_block(self.norm2(x))
|
||||
else:
|
||||
sa_input = x
|
||||
x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
|
||||
x = self.norm2(x + self._ff_block(x))
|
||||
|
||||
return x, sa_input
|
||||
|
||||
# self-attention block
|
||||
def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore
|
||||
_, T, _ = x.shape
|
||||
_, H, _ = x_past.shape
|
||||
|
||||
queries = x
|
||||
keys = torch.cat([x_past, x], dim=1)
|
||||
values = keys
|
||||
|
||||
queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
|
||||
keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
|
||||
delta = queries_pos - keys_pos
|
||||
valid_access = (delta >= 0) & (delta <= past_context)
|
||||
x = self.self_attn(
|
||||
queries, keys, values, attn_mask=~valid_access, need_weights=False
|
||||
)[0]
|
||||
return self.dropout1(x)
|
||||
|
||||
|
||||
class StreamingTransformerEncoder(nn.Module):
|
||||
"""TransformerEncoder with streaming support.
|
||||
|
||||
Args:
|
||||
dim (int): dimension of the data.
|
||||
hidden_scale (float): intermediate dimension of the FF module is this times the model dimension.
|
||||
num_heads (int): number of heads.
|
||||
num_layers (int): number of layers.
|
||||
max_period (float): maximum period of cosines in the positional embedding.
|
||||
past_context (int or None): receptive field for the causal mask, infinite if None.
|
||||
gelu (bool): if true uses GeLUs, otherwise use ReLUs.
|
||||
norm_in (bool): normalize the input.
|
||||
dropout (float): dropout probability.
|
||||
**kwargs: See `nn.TransformerEncoderLayer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
hidden_scale: float = 4.0,
|
||||
num_heads: int = 8,
|
||||
num_layers: int = 5,
|
||||
max_period: float = 10000,
|
||||
past_context: int = 1000,
|
||||
gelu: bool = True,
|
||||
norm_in: bool = True,
|
||||
dropout: float = 0.0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
hidden_dim = int(dim * hidden_scale)
|
||||
|
||||
self.max_period = max_period
|
||||
self.past_context = past_context
|
||||
activation: Any = F.gelu if gelu else F.relu
|
||||
|
||||
self.norm_in: nn.Module
|
||||
if norm_in:
|
||||
self.norm_in = nn.LayerNorm(dim)
|
||||
else:
|
||||
self.norm_in = nn.Identity()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for idx in range(num_layers):
|
||||
self.layers.append(
|
||||
StreamingTransformerEncoderLayer(
|
||||
dim,
|
||||
num_heads,
|
||||
hidden_dim,
|
||||
activation=activation,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
offset: Union[int, Tensor] = 0,
|
||||
):
|
||||
B, T, C = x.shape
|
||||
if states is None:
|
||||
states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
|
||||
|
||||
positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
|
||||
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
|
||||
|
||||
new_state: List[Tensor] = []
|
||||
x = self.norm_in(x)
|
||||
x = x + pos_emb
|
||||
|
||||
for layer_state, layer in zip(states, self.layers):
|
||||
x, new_layer_state = layer(x, layer_state, self.past_context)
|
||||
new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
|
||||
new_state.append(new_layer_state[:, -self.past_context :, :])
|
||||
return x, new_state, offset + T
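# Minimal illustrative sketch (hypothetical sizes, assuming a model small enough to run
# on CPU): the per-layer states and the offset returned by `forward` can be fed back in
# to process a long sequence chunk by chunk, keeping at most `past_context` frames of
# history per layer.
def _example_streaming_usage():
    encoder = StreamingTransformerEncoder(dim=64, num_heads=4, num_layers=2, past_context=100)
    x = torch.randn(1, 400, 64)
    states, offset, outs = None, 0, []
    for chunk in x.split(100, dim=1):
        y, states, offset = encoder(chunk, states, offset)
        outs.append(y)
    assert torch.cat(outs, dim=1).shape == x.shape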
|
7
egs/libritts/CODEC/encodec/quantization/__init__.py
Normal file
7
egs/libritts/CODEC/encodec/quantization/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# flake8: noqa
|
||||
from .vq import QuantizedResult, ResidualVectorQuantizer
|
311
egs/libritts/CODEC/encodec/quantization/ac.py
Normal file
311
egs/libritts/CODEC/encodec/quantization/ac.py
Normal file
@ -0,0 +1,311 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Arithmetic coder."""
|
||||
import io
|
||||
import math
|
||||
import random
|
||||
from typing import IO, Any, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from ..binary import BitPacker, BitUnpacker
|
||||
|
||||
|
||||
def build_stable_quantized_cdf(
|
||||
pdf: Tensor,
|
||||
total_range_bits: int,
|
||||
roundoff: float = 1e-8,
|
||||
min_range: int = 2,
|
||||
check: bool = True,
|
||||
) -> Tensor:
|
||||
"""Turn the given PDF into a quantized CDF that splits
|
||||
[0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
|
||||
to the PDF.
|
||||
|
||||
Args:
|
||||
pdf (Tensor): probability distribution, shape should be `[N]`.
|
||||
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
|
||||
during the coding process is `[0, 2 ** total_range_bits - 1]`.
|
||||
roundoff (float): will round the pdf up to that level to remove difference coming
|
||||
from e.g. evaluating the Language Model on different architectures.
|
||||
min_range (int): minimum range width. Should always be at least 2 for numerical
|
||||
stability. Use this to avoid pathological behavior if a value
|
||||
that is expected to be rare actually happens in real life.
|
||||
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
|
||||
"""
|
||||
pdf = pdf.detach()
|
||||
if roundoff:
|
||||
pdf = (pdf / roundoff).floor() * roundoff
|
||||
# interpolate with uniform distribution to achieve desired minimum probability.
|
||||
total_range = 2**total_range_bits
|
||||
cardinality = len(pdf)
|
||||
alpha = min_range * cardinality / total_range
|
||||
assert alpha <= 1, "you must reduce min_range"
|
||||
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
|
||||
ranges += min_range
|
||||
quantized_cdf = torch.cumsum(ranges, dim=-1)
|
||||
if min_range < 2:
|
||||
raise ValueError("min_range must be at least 2.")
|
||||
if check:
|
||||
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
|
||||
if (
|
||||
(quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
|
||||
).any() or quantized_cdf[0] < min_range:
|
||||
raise ValueError("You must increase your total_range_bits.")
|
||||
return quantized_cdf
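# Minimal illustrative sketch (hypothetical values): the quantized CDF is a monotonically
# increasing integer tensor whose last entry stays within 2 ** total_range_bits, with
# every bin at least `min_range` (default 2) wide.
def _example_quantized_cdf():
    pdf = torch.softmax(torch.randn(256), dim=0)
    cdf = build_stable_quantized_cdf(pdf, total_range_bits=24)
    assert cdf[-1] <= 2**24
    assert (cdf[1:] - cdf[:-1]).min() >= 2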
|
||||
|
||||
|
||||
class ArithmeticCoder:
|
||||
"""ArithmeticCoder,
|
||||
Let us take a distribution `p` over `N` symbols, and assume we have a stream
|
||||
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
|
||||
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
|
||||
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
|
||||
sequence `(s_t)` by doing the following:
|
||||
|
||||
1) Initialize the current range to `[0, 2 ** B - 1]`.
|
||||
2) For each time step t, split the current range into contiguous chunks,
|
||||
one for each possible outcome, with size roughly proportional to `p`.
|
||||
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
|
||||
would be `{[0, 2], [3, 3]}`.
|
||||
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
|
||||
4) When done encoding all the values, just select any value remaining in the range.
|
||||
|
||||
You will notice that this procedure can fail: for instance if at any point in time
|
||||
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
|
||||
possible outcome. Intuitively, the more likely a value is, the less the range width
|
||||
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
|
||||
coding scheme, likely outcomes would take less bits, and more of them can be coded
|
||||
with a fixed budget.
|
||||
|
||||
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
|
||||
when the current range decreases below a given limit (given by `total_range_bits`), without
|
||||
having to redo all the computations. If we mostly encode likely values, we will seldom
|
||||
need to inject new bits, but a single rare value can deplete our stock of entropy!
|
||||
|
||||
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
|
||||
code works for any sequence `(p_t)` possibly different for each timestep.
|
||||
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
|
||||
the KL between the true distribution and `p_t`, the more efficient the coding will be.
|
||||
|
||||
Args:
|
||||
fo (IO[bytes]): file-like object to which the bytes will be written.
|
||||
total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
|
||||
Any time the current range width falls under this limit, new bits will
|
||||
be injected to rescale the initial range.
|
||||
"""
|
||||
|
||||
def __init__(self, fo: IO[bytes], total_range_bits: int = 24):
|
||||
assert total_range_bits <= 30
|
||||
self.total_range_bits = total_range_bits
|
||||
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
|
||||
self.low: int = 0
|
||||
self.high: int = 0
|
||||
self.max_bit: int = -1
|
||||
self._dbg: List[Any] = []
|
||||
self._dbg2: List[Any] = []
|
||||
|
||||
@property
|
||||
def delta(self) -> int:
|
||||
"""Return the current range width."""
|
||||
return self.high - self.low + 1
|
||||
|
||||
def _flush_common_prefix(self):
|
||||
# If self.low and self.high start with the same bits,
|
||||
# those won't change anymore as we always just increase the range
|
||||
# by powers of 2, and we can flush them out to the bit stream.
|
||||
assert self.high >= self.low, (self.low, self.high)
|
||||
assert self.high < 2 ** (self.max_bit + 1)
|
||||
while self.max_bit >= 0:
|
||||
b1 = self.low >> self.max_bit
|
||||
b2 = self.high >> self.max_bit
|
||||
if b1 == b2:
|
||||
self.low -= b1 << self.max_bit
|
||||
self.high -= b1 << self.max_bit
|
||||
assert self.high >= self.low, (self.high, self.low, self.max_bit)
|
||||
assert self.low >= 0
|
||||
self.max_bit -= 1
|
||||
self.packer.push(b1)
|
||||
else:
|
||||
break
|
||||
|
||||
def push(self, symbol: int, quantized_cdf: Tensor):
|
||||
"""Push the given symbol on the stream, flushing out bits
|
||||
if possible.
|
||||
|
||||
Args:
|
||||
symbol (int): symbol to encode with the AC.
|
||||
quantized_cdf (Tensor): use `build_stable_quantized_cdf`
|
||||
to build this from your pdf estimate.
|
||||
"""
|
||||
while self.delta < 2**self.total_range_bits:
|
||||
self.low *= 2
|
||||
self.high = self.high * 2 + 1
|
||||
self.max_bit += 1
|
||||
|
||||
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
|
||||
range_high = quantized_cdf[symbol].item() - 1
|
||||
effective_low = int(
|
||||
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
effective_high = int(
|
||||
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
assert self.low <= self.high
|
||||
self.high = self.low + effective_high
|
||||
self.low = self.low + effective_low
|
||||
assert self.low <= self.high, (
|
||||
effective_low,
|
||||
effective_high,
|
||||
range_low,
|
||||
range_high,
|
||||
)
|
||||
self._dbg.append((self.low, self.high))
|
||||
self._dbg2.append((self.low, self.high))
|
||||
outs = self._flush_common_prefix()
|
||||
assert self.low <= self.high
|
||||
assert self.max_bit >= -1
|
||||
assert self.max_bit <= 61, self.max_bit
|
||||
return outs
|
||||
|
||||
def flush(self):
|
||||
"""Flush the remaining information to the stream."""
|
||||
while self.max_bit >= 0:
|
||||
b1 = (self.low >> self.max_bit) & 1
|
||||
self.packer.push(b1)
|
||||
self.max_bit -= 1
|
||||
self.packer.flush()
|
||||
|
||||
|
||||
class ArithmeticDecoder:
|
||||
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
|
||||
|
||||
Note that this must be called with **exactly** the same parameters and sequence
|
||||
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
|
||||
|
||||
If the AC encoder current range is [L, H], with `L` and `H` having the same common
|
||||
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
|
||||
For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
|
||||
`[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
|
||||
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
|
||||
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
|
||||
and we will need to read new bits from the stream and repeat the process.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, fo: IO[bytes], total_range_bits: int = 24):
|
||||
self.total_range_bits = total_range_bits
|
||||
self.low: int = 0
|
||||
self.high: int = 0
|
||||
self.current: int = 0
|
||||
self.max_bit: int = -1
|
||||
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
|
||||
# Following is for debugging
|
||||
self._dbg: List[Any] = []
|
||||
self._dbg2: List[Any] = []
|
||||
self._last: Any = None
|
||||
|
||||
@property
|
||||
def delta(self) -> int:
|
||||
return self.high - self.low + 1
|
||||
|
||||
def _flush_common_prefix(self):
|
||||
# Given the current range [L, H], if both have a common prefix,
|
||||
# we know we can remove it from our representation to avoid handling large numbers.
|
||||
while self.max_bit >= 0:
|
||||
b1 = self.low >> self.max_bit
|
||||
b2 = self.high >> self.max_bit
|
||||
if b1 == b2:
|
||||
self.low -= b1 << self.max_bit
|
||||
self.high -= b1 << self.max_bit
|
||||
self.current -= b1 << self.max_bit
|
||||
assert self.high >= self.low
|
||||
assert self.low >= 0
|
||||
self.max_bit -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
def pull(self, quantized_cdf: Tensor) -> Optional[int]:
|
||||
"""Pull a symbol, reading as many bits from the stream as required.
|
||||
This returns `None` when the stream has been exhausted.
|
||||
|
||||
Args:
|
||||
quantized_cdf (Tensor): use `build_stable_quantized_cdf`
|
||||
to build this from your pdf estimate. This must be **exactly**
|
||||
the same cdf as the one used at encoding time.
|
||||
"""
|
||||
while self.delta < 2**self.total_range_bits:
|
||||
bit = self.unpacker.pull()
|
||||
if bit is None:
|
||||
return None
|
||||
self.low *= 2
|
||||
self.high = self.high * 2 + 1
|
||||
self.current = self.current * 2 + bit
|
||||
self.max_bit += 1
|
||||
|
||||
def bin_search(low_idx: int, high_idx: int):
|
||||
# Binary search is not just for coding interviews :)
|
||||
if high_idx < low_idx:
|
||||
raise RuntimeError("Binary search failed")
|
||||
mid = (low_idx + high_idx) // 2
|
||||
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
|
||||
range_high = quantized_cdf[mid].item() - 1
|
||||
effective_low = int(
|
||||
math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
effective_high = int(
|
||||
math.floor(range_high * (self.delta / (2**self.total_range_bits)))
|
||||
)
|
||||
low = effective_low + self.low
|
||||
high = effective_high + self.low
|
||||
if self.current >= low:
|
||||
if self.current <= high:
|
||||
return (mid, low, high, self.current)
|
||||
else:
|
||||
return bin_search(mid + 1, high_idx)
|
||||
else:
|
||||
return bin_search(low_idx, mid - 1)
|
||||
|
||||
self._last = (self.low, self.high, self.current, self.max_bit)
|
||||
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
|
||||
self._dbg.append((self.low, self.high, self.current))
|
||||
self._flush_common_prefix()
|
||||
self._dbg2.append((self.low, self.high, self.current))
|
||||
|
||||
return sym
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(1234)
|
||||
random.seed(1234)
|
||||
for _ in range(4):
|
||||
pdfs = []
|
||||
cardinality = random.randrange(4000)
|
||||
steps = random.randrange(100, 500)
|
||||
fo = io.BytesIO()
|
||||
encoder = ArithmeticCoder(fo)
|
||||
symbols = []
|
||||
for step in range(steps):
|
||||
pdf = torch.softmax(torch.randn(cardinality), dim=0)
|
||||
pdfs.append(pdf)
|
||||
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
||||
symbol = torch.multinomial(pdf, 1).item()
|
||||
symbols.append(symbol)
|
||||
encoder.push(symbol, q_cdf)
|
||||
encoder.flush()
|
||||
|
||||
fo.seek(0)
|
||||
decoder = ArithmeticDecoder(fo)
|
||||
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
|
||||
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
||||
decoded_symbol = decoder.pull(q_cdf)
|
||||
assert decoded_symbol == symbol, idx
|
||||
assert decoder.pull(torch.zeros(1)) is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
377
egs/libritts/CODEC/encodec/quantization/core_vq.py
Normal file
377
egs/libritts/CODEC/encodec/quantization/core_vq.py
Normal file
@ -0,0 +1,377 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# This implementation is inspired from
|
||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
||||
# which is released under MIT License. Hereafter, the original license:
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Phil Wang
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
"""Core vector quantization implementation."""
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
from .distrib import broadcast_tensors
|
||||
|
||||
|
||||
def default(val: Any, d: Any) -> Any:
|
||||
return val if val is not None else d
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay: float):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
||||
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
||||
|
||||
|
||||
def uniform_init(*shape: int):
|
||||
t = torch.empty(shape)
|
||||
nn.init.kaiming_uniform_(t)
|
||||
return t
|
||||
|
||||
|
||||
def sample_vectors(samples, num: int):
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num,), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
||||
dim, dtype = samples.shape[-1], samples.dtype
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
||||
dists = -(diffs**2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
|
||||
return means, bins
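# Minimal illustrative sketch (hypothetical sizes): k-means is only used to seed the
# codebook from the first batch; `means` holds the centroids and `bins` the number of
# samples assigned to each of them.
def _example_kmeans():
    samples = torch.randn(1024, 16)
    means, bins = kmeans(samples, num_clusters=8, num_iters=5)
    assert means.shape == (8, 16)
    assert int(bins.sum()) == 1024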
|
||||
|
||||
|
||||
class EuclideanCodebook(nn.Module):
|
||||
"""Codebook with Euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension.
|
||||
codebook_size (int): Codebook size.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
||||
If set to true, run the k-means algorithm on the first training batch and use
|
||||
the learned centroids as initialization.
|
||||
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
a randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
kmeans_init: bool = False,
|
||||
kmeans_iters: int = 10,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: Union[Callable[..., torch.Tensor], Any] = (
|
||||
uniform_init if not kmeans_init else torch.zeros
|
||||
)
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.epsilon = epsilon
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
|
||||
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
||||
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_embed_(self, data):
|
||||
if self.inited:
|
||||
return
|
||||
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed.clone())
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
if self.threshold_ema_dead_code == 0:
|
||||
return
|
||||
|
||||
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
broadcast_tensors(self.buffers())
|
||||
|
||||
def preprocess(self, x):
|
||||
x = rearrange(x, "... d -> (...) d")
|
||||
return x
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
def postprocess_emb(self, embed_ind, shape):
|
||||
return embed_ind.view(*shape[:-1])
|
||||
|
||||
def dequantize(self, embed_ind):
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
return quantize
|
||||
|
||||
def encode(self, x):
|
||||
shape = x.shape
|
||||
# pre-process
|
||||
x = self.preprocess(x)
|
||||
# quantize
|
||||
embed_ind = self.quantize(x)
|
||||
# post-process
|
||||
embed_ind = self.postprocess_emb(embed_ind, shape)
|
||||
return embed_ind
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self.dequantize(embed_ind)
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
x = self.preprocess(x)
|
||||
|
||||
self.init_embed_(x)
|
||||
|
||||
embed_ind = self.quantize(x)
|
||||
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
||||
embed_ind = self.postprocess_emb(embed_ind, shape)
|
||||
quantize = self.dequantize(embed_ind)
|
||||
|
||||
if self.training:
|
||||
# We do the expiry of code at that point as buffers are in sync
|
||||
# and all the workers will take the same decision.
|
||||
self.expire_codes_(x)
|
||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
||||
* self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
|
||||
class VectorQuantization(nn.Module):
|
||||
"""Vector quantization implementation.
|
||||
Currently supports only euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension
|
||||
codebook_size (int): Codebook size
|
||||
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
a randomly selected vector from the current batch.
|
||||
commitment_weight (float): Weight for commitment loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
codebook_dim: Optional[int] = None,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
commitment_weight: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (
|
||||
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
|
||||
self._codebook = EuclideanCodebook(
|
||||
dim=_codebook_dim,
|
||||
codebook_size=codebook_size,
|
||||
kmeans_init=kmeans_init,
|
||||
kmeans_iters=kmeans_iters,
|
||||
decay=decay,
|
||||
epsilon=epsilon,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code,
|
||||
)
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self._codebook.embed
|
||||
|
||||
def encode(self, x):
|
||||
x = rearrange(x, "b d n -> b n d")
|
||||
x = self.project_in(x)
|
||||
embed_in = self._codebook.encode(x)
|
||||
return embed_in
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self._codebook.decode(embed_ind)
|
||||
quantize = self.project_out(quantize)
|
||||
quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
x = rearrange(x, "b d n -> b n d")
|
||||
x = self.project_in(x)
|
||||
|
||||
quantize, embed_ind = self._codebook(x)
|
||||
|
||||
if self.training:
|
||||
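# Straight-through estimator: the forward pass uses the quantized values,
# while gradients flow to x as if quantization were the identity.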
quantize = x + (quantize - x).detach()
|
||||
|
||||
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
||||
|
||||
if self.training:
|
||||
if self.commitment_weight > 0:
|
||||
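# Commitment loss pulls the encoder output toward its (detached) codebook entry.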
commit_loss = F.mse_loss(quantize.detach(), x)
|
||||
loss = loss + commit_loss * self.commitment_weight
|
||||
|
||||
quantize = self.project_out(quantize)
|
||||
quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
|
||||
class ResidualVectorQuantization(nn.Module):
|
||||
"""Residual vector quantization implementation.
|
||||
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, *, num_quantizers, **kwargs):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
||||
)
|
||||
|
||||
def forward(self, x, n_q: Optional[int] = None):
|
||||
quantized_out = 0.0
|
||||
residual = x
|
||||
|
||||
all_losses = []
|
||||
all_indices = []
|
||||
|
||||
n_q = n_q or len(self.layers)
|
||||
|
||||
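# Each stage quantizes the residual left by the previous stages; the sum of the
# per-stage outputs approximates the input.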
for layer in self.layers[:n_q]:
|
||||
quantized, indices, loss = layer(residual)
|
||||
residual = residual - quantized
|
||||
quantized_out = quantized_out + quantized
|
||||
|
||||
all_indices.append(indices)
|
||||
all_losses.append(loss)
|
||||
|
||||
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
||||
return quantized_out, out_indices, out_losses
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
st = st or 0
|
||||
for layer in self.layers[st:n_q]:  # set the start and end quantizer layers to use
|
||||
indices = layer.encode(residual)
|
||||
quantized = layer.decode(indices)
|
||||
residual = residual - quantized
|
||||
all_indices.append(indices)
|
||||
out_indices = torch.stack(all_indices)
|
||||
return out_indices
|
||||
|
||||
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
||||
for i, indices in enumerate(q_indices):
|
||||
layer = self.layers[i]
|
||||
quantized = layer.decode(indices)
|
||||
quantized_out = quantized_out + quantized
|
||||
return quantized_out
|
126
egs/libritts/CODEC/encodec/quantization/distrib.py
Normal file
126
egs/libritts/CODEC/encodec/quantization/distrib.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Torch distributed utilities."""
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
|
||||
def rank():
|
||||
if dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def world_size():
|
||||
if dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def is_distributed():
|
||||
return world_size() > 1
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM):
|
||||
if is_distributed():
|
||||
return dist.all_reduce(tensor, op)
|
||||
|
||||
|
||||
def _is_complex_or_float(tensor):
|
||||
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
||||
|
||||
|
||||
def _check_number_of_params(params: List[torch.Tensor]):
|
||||
# utility function to check that the number of params in all workers is the same,
|
||||
# and thus avoid a deadlock with distributed all reduce.
|
||||
if not is_distributed() or not params:
|
||||
return
|
||||
# print('params[0].device ', params[0].device)
|
||||
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
||||
all_reduce(tensor)
|
||||
if tensor.item() != len(params) * world_size():
|
||||
# If not all the workers have the same number, for at least one of them,
|
||||
# this inequality will be verified.
|
||||
raise RuntimeError(
|
||||
f"Mismatch in number of params: ours is {len(params)}, "
|
||||
"at least one worker has a different one."
|
||||
)
|
||||
|
||||
|
||||
def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0):
|
||||
"""Broadcast the tensors from the given parameters to all workers.
|
||||
This can be used to ensure that all workers have the same model to start with.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
||||
_check_number_of_params(tensors)
|
||||
handles = []
|
||||
for tensor in tensors:
|
||||
# src = int(rank()) # added code
|
||||
handle = dist.broadcast(tensor.data, src=src, async_op=True)
|
||||
handles.append(handle)
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
|
||||
def sync_buffer(buffers, average=True):
|
||||
"""
|
||||
Sync buffers across workers. If average is False, broadcast instead of averaging.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
handles = []
|
||||
for buffer in buffers:
|
||||
if torch.is_floating_point(buffer.data):
|
||||
if average:
|
||||
handle = dist.all_reduce(
|
||||
buffer.data, op=dist.ReduceOp.SUM, async_op=True
|
||||
)
|
||||
else:
|
||||
handle = dist.broadcast(buffer.data, src=0, async_op=True)
|
||||
handles.append((buffer, handle))
|
||||
for buffer, handle in handles:
|
||||
handle.wait()
|
||||
if average:
|
||||
buffer.data /= world_size()
|
||||
|
||||
|
||||
def sync_grad(params):
|
||||
"""
|
||||
Simpler alternative to DistributedDataParallel, that doesn't rely
|
||||
on any black magic. For simple models it can also be as fast.
|
||||
Just call this on your model parameters after the call to backward!
|
||||
"""
|
||||
if not is_distributed():
|
||||
return
|
||||
handles = []
|
||||
for p in params:
|
||||
if p.grad is not None:
|
||||
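# Sum gradients across workers asynchronously; they are averaged below by
# dividing by the world size.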
handle = dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=True)
|
||||
handles.append((p, handle))
|
||||
for p, handle in handles:
|
||||
handle.wait()
|
||||
p.grad.data /= world_size()
|
||||
|
||||
|
||||
def average_metrics(metrics: Dict[str, float], count=1.0):
|
||||
"""Average a dictionary of metrics across all workers, using the optional
|
||||
`count` as unnormalized weight.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return metrics
|
||||
keys, values = zip(*metrics.items())
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
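# The trailing 1 is scaled by `count` below, so the normalizer is all-reduced
# together with the weighted metric values.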
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
||||
tensor *= count
|
||||
all_reduce(tensor)
|
||||
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
||||
return dict(zip(keys, averaged))
|
121
egs/libritts/CODEC/encodec/quantization/vq.py
Normal file
121
egs/libritts/CODEC/encodec/quantization/vq.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Residual vector quantizer implementation."""
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .core_vq import ResidualVectorQuantization
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizedResult:
|
||||
quantized: Tensor
|
||||
codes: Tensor
|
||||
bandwidth: Tensor # bandwidth in kb/s used, per batch item.
|
||||
penalty: Optional[Tensor] = None
|
||||
metrics: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class ResidualVectorQuantizer(nn.Module):
|
||||
"""Residual Vector Quantizer.
|
||||
Args:
|
||||
dimension (int): Dimension of the codebooks.
|
||||
n_q (int): Number of residual vector quantizers used.
|
||||
bins (int): Codebook size.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
a randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int = 256,
|
||||
n_q: int = 8,
|
||||
bins: int = 1024,
|
||||
decay: float = 0.99,
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_q = n_q
|
||||
self.dimension = dimension
|
||||
self.bins = bins
|
||||
self.decay = decay
|
||||
self.kmeans_init = kmeans_init
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.vq = ResidualVectorQuantization(
|
||||
dim=self.dimension,
|
||||
codebook_size=self.bins,
|
||||
num_quantizers=self.n_q,
|
||||
decay=self.decay,
|
||||
kmeans_init=self.kmeans_init,
|
||||
kmeans_iters=self.kmeans_iters,
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, sample_rate: int, bandwidth: Optional[float] = None
|
||||
) -> QuantizedResult:
|
||||
"""Residual vector quantization on the given input tensor.
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
sample_rate (int): Sample rate of the input tensor.
|
||||
bandwidth (float): Target bandwidth.
|
||||
Returns:
|
||||
QuantizedResult:
|
||||
The quantized (or approximately quantized) representation with
|
||||
the associated bandwidth and any penalty term for the loss.
|
||||
"""
|
||||
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
|
||||
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
|
||||
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
|
||||
bw = torch.tensor(n_q * bw_per_q).to(x)
|
||||
return quantized, codes, bw, torch.mean(commit_loss)
|
||||
# return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
|
||||
|
||||
def get_num_quantizers_for_bandwidth(
|
||||
self, sample_rate: int, bandwidth: Optional[float] = None
|
||||
) -> int:
|
||||
"""Return n_q based on specified target bandwidth."""
|
||||
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
|
||||
n_q = self.n_q
|
||||
if bandwidth and bandwidth > 0.0:
|
||||
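# Use as many quantizers as the target bandwidth allows, but always at least one.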
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
|
||||
return n_q
|
||||
|
||||
def get_bandwidth_per_quantizer(self, sample_rate: int):
|
||||
"""Return bandwidth per quantizer for a given input sample rate."""
|
||||
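# Each quantizer spends log2(bins) bits per frame; sample_rate here is the frame
# rate of the quantized representation, and dividing by 1000 converts to kb/s.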
return math.log2(self.bins) * sample_rate / 1000
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: Tensor,
|
||||
sample_rate: int,
|
||||
bandwidth: Optional[float] = None,
|
||||
st: Optional[int] = None,
|
||||
) -> Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizers to use
|
||||
and returns indices for each quantizer.
|
||||
"""
|
||||
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
|
||||
st = st or 0
|
||||
codes = self.vq.encode(x, n_q=n_q, st=st)
|
||||
return codes
|
||||
|
||||
def decode(self, codes: Tensor) -> Tensor:
|
||||
"""Decode the given codes to the quantized representation."""
|
||||
quantized = self.vq.decode(codes)
|
||||
return quantized
|
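# A minimal usage sketch of ResidualVectorQuantizer (illustrative only; the
# shapes and the 75 Hz frame rate below are assumptions, not values taken from
# this recipe):
#
#   rvq = ResidualVectorQuantizer(dimension=512, n_q=8, bins=1024)
#   x = torch.randn(2, 512, 75)  # (batch, dimension, frames)
#   quantized, codes, bandwidth, penalty = rvq(x, sample_rate=75, bandwidth=6.0)
#   codes = rvq.encode(x, sample_rate=75, bandwidth=6.0)
#   x_hat = rvq.decode(codes)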
902
egs/libritts/CODEC/encodec/train.py
Normal file
902
egs/libritts/CODEC/encodec/train.py
Normal file
@ -0,0 +1,902 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from encodec import Encodec
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils import MetricsTracker, plot_feature, save_checkpoint
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Resume training from this epoch. It should be positive.
|
||||
If larger than 1, it will load checkpoint from
|
||||
exp-dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="vits/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr", type=float, default=3.0e-4, help="The base learning rate."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--print-diagnostics",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Accumulate stats on activations, print them and exit.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--inf-check",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Add hooks to check for infinite module outputs and gradients.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""Save checkpoint after processing this number of epochs"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.cur_epoch % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
|
||||
Since it will take around 1000 epochs, we suggest using a large
|
||||
save_every_n to save disk space.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
- batch_idx_train: Used to write statistics to tensorboard. It
|
||||
contains number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
- log_interval: Print training loss if batch_idx % log_interval is 0
|
||||
|
||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||
|
||||
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
# training params
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": -1, # 0
|
||||
"log_interval": 50,
|
||||
"valid_interval": 200,
|
||||
"env_info": get_env_info(),
|
||||
"sampling_rate": 24000,
|
||||
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
|
||||
"lambda_wav": 100.0, # loss scaling coefficient for waveform loss
|
||||
"lambda_feat": 1.0, # loss scaling coefficient for feat loss
|
||||
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
|
||||
"lambda_com": 1000.0, # loss scaling coefficient for commitment loss
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict, model: nn.Module
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is larger than 1, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
Returns:
|
||||
Return a dict containing previously saved training info.
|
||||
"""
|
||||
if params.start_epoch > 1:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
return None
|
||||
|
||||
assert filename.is_file(), f"{filename} does not exist!"
|
||||
|
||||
saved_params = load_checkpoint(filename, model=model)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
"""Get the model based on the configuration."""
|
||||
|
||||
from discriminators import (
|
||||
MultiPeriodDiscriminator,
|
||||
MultiScaleDiscriminator,
|
||||
MultiScaleSTFTDiscriminator,
|
||||
)
|
||||
from modules.seanet import SEANetDecoder, SEANetEncoder
|
||||
from quantization import ResidualVectorQuantizer
|
||||
|
||||
generator_params = {
|
||||
"generator_n_filters": 32,
|
||||
"dimension": 512,
|
||||
"ratios": [2, 2, 2, 4],
|
||||
"target_bandwidths": [7.5, 15],
|
||||
"bins": 1024,
|
||||
}
|
||||
discriminator_params = {
|
||||
"stft_discriminator_n_filters": 32,
|
||||
}
|
||||
|
||||
params.update(generator_params)
|
||||
params.update(discriminator_params)
|
||||
|
||||
hop_length = np.prod(params.ratios)
|
||||
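# Number of quantizers needed for the largest target bandwidth: each 1024-entry
# codebook costs 10 bits per frame, and the frame rate is sampling_rate / hop_length.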
n_q = int(
|
||||
1000
|
||||
* params.target_bandwidths[-1]
|
||||
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
||||
)
|
||||
|
||||
encoder = SEANetEncoder(
|
||||
n_filters=params.generator_n_filters, dimension=params.dimension, ratios=params.ratios
|
||||
)
|
||||
decoder = SEANetDecoder(
|
||||
n_filters=params.generator_n_filters, dimension=params.dimension, ratios=params.ratios
|
||||
)
|
||||
quantizer = ResidualVectorQuantizer(
|
||||
dimension=params.dimension, n_q=n_q, bins=params.bins
|
||||
)
|
||||
|
||||
model = Encodec(
|
||||
params=params,
|
||||
sample_rate=params.sampling_rate,
|
||||
target_bandwidths=params.target_bandwidths,
|
||||
encoder=encoder,
|
||||
quantizer=quantizer,
|
||||
decoder=decoder,
|
||||
multi_scale_discriminator=MultiScaleDiscriminator(),
|
||||
multi_period_discriminator=MultiPeriodDiscriminator(),
|
||||
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(),
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def prepare_input(
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Parse batch data"""
|
||||
audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
|
||||
features = batch["features"].to(device, memory_format=torch.contiguous_format)
|
||||
audio_lens = batch["audio_lens"].to(device)
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
|
||||
return audio, audio_lens, features, features_lens
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
optimizer_g: Optimizer,
|
||||
optimizer_d: Optimizer,
|
||||
scheduler_g: LRSchedulerType,
|
||||
scheduler_d: LRSchedulerType,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model to be trained.
|
||||
optimizer_g:
|
||||
The optimizer for generator.
|
||||
optimizer_d:
|
||||
The optimizer for discriminator.
|
||||
scheduler_g:
|
||||
The learning rate scheduler for generator, we call step() every epoch.
|
||||
scheduler_d:
|
||||
The learning rate scheduler for discriminator, we call step() every epoch.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
rank:
|
||||
The rank of the node in DDP training. If no DDP is used, it should
|
||||
be set to 0.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
|
||||
# used to summary the stats over iterations in one epoch
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
saved_bad_model = False
|
||||
|
||||
def save_bad_model(suffix: str = ""):
|
||||
save_checkpoint(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
|
||||
batch_size = len(batch["tokens"])
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
_,
|
||||
_,
|
||||
) = prepare_input(batch, device)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
try:
|
||||
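# GAN-style alternating update: the discriminators are updated first on the
# current batch, then the generator (encoder, quantizer, decoder).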
with autocast(enabled=params.use_fp16):
|
||||
# forward discriminator
|
||||
loss_d, stats_d = model(
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
return_sample=False,
|
||||
forward_generator=False,
|
||||
)
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
# update discriminator
|
||||
optimizer_d.zero_grad()
|
||||
scaler.scale(loss_d).backward()
|
||||
scaler.step(optimizer_d)
|
||||
|
||||
with autocast(enabled=params.use_fp16):
|
||||
# forward generator
|
||||
loss_g, stats_g = model(
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
forward_generator=True,
|
||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||
)
|
||||
for k, v in stats_g.items():
|
||||
if "returned_sample" not in k:
|
||||
loss_info[k] = v * batch_size
|
||||
# update generator
|
||||
optimizer_g.zero_grad()
|
||||
scaler.scale(loss_g).backward()
|
||||
scaler.step(optimizer_g)
|
||||
scaler.update()
|
||||
|
||||
# summary stats
|
||||
tot_loss = tot_loss + loss_info
|
||||
except: # noqa
|
||||
save_bad_model()
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||
# of the grad scaler is configurable, but we can't configure it to have different
|
||||
# behavior depending on the current grad scale.
|
||||
cur_grad_scale = scaler._scale.item()
|
||||
|
||||
if cur_grad_scale < 8.0 or (
|
||||
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
|
||||
):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
if cur_grad_scale < 0.01:
|
||||
if not saved_bad_model:
|
||||
save_bad_model(suffix="-first-warning")
|
||||
saved_bad_model = True
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.log_interval == 0:
|
||||
cur_lr_g = max(scheduler_g.get_last_lr())
|
||||
cur_lr_d = max(scheduler_d.get_last_lr())
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
|
||||
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||
f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate_g", cur_lr_g, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate_d", cur_lr_d, params.batch_idx_train
|
||||
)
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
if params.use_fp16:
|
||||
tb_writer.add_scalar(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
)
|
||||
if "returned_sample" in stats_g:
|
||||
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
|
||||
tb_writer.add_audio(
|
||||
"train/speech_hat_",
|
||||
speech_hat_,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/speech_",
|
||||
speech_,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_image(
|
||||
"train/mel_hat_",
|
||||
plot_feature(mel_hat_),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
tb_writer.add_image(
|
||||
"train/mel_",
|
||||
plot_feature(mel_),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
|
||||
if (
|
||||
params.batch_idx_train % params.valid_interval == 0
|
||||
and not params.print_diagnostics
|
||||
):
|
||||
logging.info("Computing validation loss")
|
||||
valid_info, (speech_hat, speech) = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/valdi_speech_hat",
|
||||
speech_hat,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
tb_writer.add_audio(
|
||||
"train/valdi_speech",
|
||||
speech,
|
||||
params.batch_idx_train,
|
||||
params.sampling_rate,
|
||||
)
|
||||
|
||||
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
|
||||
# used to summary the stats over iterations
|
||||
tot_loss = MetricsTracker()
|
||||
returned_sample = None
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
batch_size = len(batch["tokens"])
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
_,
|
||||
_,
|
||||
) = prepare_input(batch, device)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
# forward discriminator
|
||||
loss_d, stats_d = model(
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
return_sample=False,
|
||||
forward_generator=False,
|
||||
)
|
||||
assert loss_d.requires_grad is False
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
|
||||
# forward generator
|
||||
loss_g, stats_g = model(
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
forward_generator=True,
|
||||
return_sample=batch_idx == 0,
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
for k, v in stats_g.items():
|
||||
loss_info[k] = v * batch_size
|
||||
|
||||
# summary stats
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
# infer for first batch:
|
||||
if batch_idx == 0 and rank == 0:
|
||||
speech_hat_, speech_, _, _ = stats_g["returned_sample"]
|
||||
|
||||
returned_sample = (speech_hat_, speech_)
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
|
||||
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss, returned_sample
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: Union[nn.Module, DDP],
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer_g: torch.optim.Optimizer,
|
||||
optimizer_d: torch.optim.Optimizer,
|
||||
params: AttributeDict,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
logging.info(
|
||||
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
_,
|
||||
_,
|
||||
) = prepare_input(batch, device)
|
||||
try:
|
||||
# for discriminator
|
||||
with autocast(enabled=params.use_fp16):
|
||||
loss_d, stats_d = model(
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
global_step=params.batch_idx_train,
|
||||
return_sample=False,
|
||||
forward_generator=False,
|
||||
)
|
||||
optimizer_d.zero_grad()
|
||||
loss_d.backward()
|
||||
# for generator
|
||||
with autocast(enabled=params.use_fp16):
|
||||
loss_g, stats_g = model(
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
forward_generator=True,
|
||||
global_step=params.batch_idx_train,
|
||||
return_sample=False,
|
||||
)
|
||||
optimizer_g.zero_grad()
|
||||
loss_g.backward()
|
||||
except Exception as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
"max_duration setting. We recommend decreasing "
|
||||
"max_duration and trying again.\n"
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
raise
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
vctk = VctkTtsDataModule(args)
|
||||
|
||||
train_cuts = vctk.train_cuts()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
encoder = model.encoder
|
||||
decoder = model.decoder
|
||||
quantizer = model.quantizer
|
||||
multi_scale_discriminator = model.multi_scale_discriminator
|
||||
multi_period_discriminator = model.multi_period_discriminator
|
||||
multi_scale_stft_discriminator = model.multi_scale_stft_discriminator
|
||||
|
||||
num_param_e = sum([p.numel() for p in encoder.parameters()])
|
||||
logging.info(f"Number of parameters in encoder: {num_param_e}")
|
||||
num_param_d = sum([p.numel() for p in decoder.parameters()])
|
||||
logging.info(f"Number of parameters in decoder: {num_param_d}")
|
||||
num_param_q = sum([p.numel() for p in quantizer.parameters()])
|
||||
logging.info(f"Number of parameters in quantizer: {num_param_q}")
|
||||
num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()])
|
||||
logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
|
||||
num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()])
|
||||
logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
|
||||
num_param_dstft = sum(
|
||||
[p.numel() for p in multi_scale_stft_discriminator.parameters()]
|
||||
)
|
||||
logging.info(
|
||||
f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}"
|
||||
)
|
||||
logging.info(
|
||||
f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}"
|
||||
)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer_g = torch.optim.AdamW(
|
||||
itertools.chain(
|
||||
encoder.parameters(),
|
||||
quantizer.parameters(),
|
||||
decoder.parameters(),
|
||||
),
|
||||
lr=params.lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
optimizer_d = torch.optim.AdamW(
|
||||
itertools.chain(
|
||||
multi_scale_stft_discriminator.parameters(),
|
||||
multi_scale_discriminator.parameters(),
|
||||
multi_period_discriminator.parameters(),
|
||||
),
|
||||
lr=params.lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
|
||||
|
||||
if checkpoints is not None:
|
||||
# load state_dict for optimizers
|
||||
if "optimizer_g" in checkpoints:
|
||||
logging.info("Loading optimizer_g state dict")
|
||||
optimizer_g.load_state_dict(checkpoints["optimizer_g"])
|
||||
if "optimizer_d" in checkpoints:
|
||||
logging.info("Loading optimizer_d state dict")
|
||||
optimizer_d.load_state_dict(checkpoints["optimizer_d"])
|
||||
|
||||
# load state_dict for schedulers
|
||||
if "scheduler_g" in checkpoints:
|
||||
logging.info("Loading scheduler_g state dict")
|
||||
scheduler_g.load_state_dict(checkpoints["scheduler_g"])
|
||||
if "scheduler_d" in checkpoints:
|
||||
logging.info("Loading scheduler_d state dict")
|
||||
scheduler_d.load_state_dict(checkpoints["scheduler_d"])
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
512
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
if params.inf_check:
|
||||
register_inf_check_hooks(model)
|
||||
|
||||
train_dl = vctk.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = vctk.valid_cuts()
|
||||
valid_dl = vctk.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||
logging.info(f"Start epoch {epoch}")
|
||||
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.print_diagnostics:
|
||||
diagnostic.print_diagnostics()
|
||||
break
|
||||
|
||||
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer_g=optimizer_g,
|
||||
optimizer_d=optimizer_d,
|
||||
scheduler_g=scheduler_g,
|
||||
scheduler_d=scheduler_d,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
if rank == 0:
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
# step per epoch
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
VctkTtsDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/libritts/CODEC/encodec/utils.py
Symbolic link
1
egs/libritts/CODEC/encodec/utils.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../vctk/TTS/vits/utils.py
|