mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Merge fd31ed5b0b0bef24daea22e06bb481b5a0cd519e into 9293edc62f4a3ebf769d66cc037d4e67953440f5
This commit is contained in:
commit
9a2b5720c4
88
egs/librilight/SSL/local/analyze_codebook.py
Executable file
88
egs/librilight/SSL/local/analyze_codebook.py
Executable file
@ -0,0 +1,88 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from collections import Counter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slows things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line options of the codebook analysis script.

    Returns:
      The ``argparse.Namespace`` produced from ``sys.argv`` with the
      attributes ``cuts_path`` (str) and ``num_clusters`` (int).
    """
    # (flag, type, default) for every supported option, in declaration order.
    option_specs = (
        ("--cuts-path", str, "data/kmeans/librispeech_cuts_dev-clean.jsonl.gz"),
        ("--num-clusters", int, 500),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        arg_parser.add_argument(flag, type=value_type, default=fallback)

    return arg_parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_codebook(args):
    """Log usage statistics of the k-means codebook found in a cut manifest.

    Each cut is read from ``args.cuts_path`` and its ``custom["kmeans"]``
    entry, a whitespace-separated string of integer cluster ids, is tallied.
    Two numbers are logged:

      - the utilization rate, i.e. the number of distinct ids observed
        divided by ``args.num_clusters``;
      - the entropy, in bits, of the empirical distribution over the ids
        ``0 .. args.num_clusters - 1``.

    Args:
      args:
        Namespace providing ``cuts_path`` (manifest to analyze) and
        ``num_clusters`` (size of the codebook).

    Raises:
      AssertionError: if ``args.cuts_path`` is not an existing file.
    """
    cuts_path = Path(args.cuts_path)
    assert cuts_path.is_file(), f"{cuts_path} does not exist"

    logging.info(f"Loading {cuts_path}")
    cut_set = CutSet.from_file(cuts_path)

    cluster_counts = Counter()
    logging.info("Analyzing codebook")
    for cut in tqdm(cut_set):
        cluster_counts.update(map(int, cut.custom["kmeans"].split()))

    total_count = sum(cluster_counts.values())
    if total_count == 0:
        # Without this guard the normalisation below divides by zero and the
        # reported entropy is NaN.
        logging.warning(f"No kmeans labels found in {cuts_path} - nothing to analyze")
        return

    # Ids outside the codebook are counted in the utilization rate and in
    # total_count, but have no slot in the histogram used for the entropy.
    # Tell the user instead of silently reporting inconsistent numbers.
    out_of_range = sorted(
        label for label in cluster_counts if not 0 <= label < args.num_clusters
    )
    if out_of_range:
        logging.warning(
            f"{len(out_of_range)} distinct labels fall outside "
            f"[0, {args.num_clusters}); check --num-clusters"
        )

    utilized_clusters = len(cluster_counts)

    counts = torch.tensor([cluster_counts[i] for i in range(args.num_clusters)])
    # Clamp so that unused entries contribute ~0 instead of 0 * log(0) = NaN.
    normalized_counts = (counts / total_count).clamp(min=1e-10)
    # The sum is in nats; multiplying by log2(e) converts it to bits.
    codebook_entropy = (
        -(normalized_counts * normalized_counts.log()).sum()
        * torch.log2(torch.tensor(torch.e))
    ).item()

    logging.info(
        f"Codebook utilization rate: {utilized_clusters / args.num_clusters:%}"
    )
    logging.info(f"Codebook entropy: {codebook_entropy}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Timestamped records that also name the emitting file and line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    args = get_args()
    # Record the effective options at the top of the log.
    logging.info(vars(args))
    analyze_codebook(args)
|
||||||
289
egs/librilight/SSL/local/extract_kmeans.py
Executable file
289
egs/librilight/SSL/local/extract_kmeans.py
Executable file
@ -0,0 +1,289 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fairseq
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, SupervisionSegment
|
||||||
|
from lhotse.utils import fastcopy
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slows things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyKmeans(object):
    """Assign each feature vector to its nearest k-means centroid.

    The squared Euclidean distance is expanded as
    ``|x|^2 - 2 * x @ C + |c|^2`` so that all distances come from a single
    matrix product. Both a NumPy and a torch copy of the centroids are kept;
    the torch copy is moved to CUDA when it is available.
    """

    def __init__(self, km_path):
        """Load the k-means model stored at ``km_path``.

        Args:
          km_path:
            Path handed to ``joblib.load``. The loaded object must expose a
            ``cluster_centers_`` array.
        """
        self.km_model = joblib.load(km_path)
        # Transposed so that centroids are the columns of C: x @ C gives one
        # row per input vector and one column per centroid.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return the index of the closest centroid for every row of ``x``.

        Args:
          x:
            2-D ``torch.Tensor`` or NumPy array with one feature vector per
            row.

        Returns:
          A 1-D NumPy array holding one centroid index per row of ``x``.
        """
        if isinstance(x, torch.Tensor):
            # torch.matmul needs both operands on the same device and with
            # the same dtype; align the input with the stored centroids
            # (this is a no-op when they already match).
            x = x.to(device=self.C.device, dtype=self.C.dtype)
            dist = (
                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x**2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line options of the k-means extraction script.

    Returns:
      The ``argparse.Namespace`` produced from ``sys.argv`` with the
      attributes ``subset``, ``model_path``, ``kmeans_model_path``,
      ``start``, ``stop``, ``window_duration`` and ``shift_duration``.
    """
    # One entry per option, in declaration order: flag plus add_argument kwargs.
    option_table = (
        ("--subset", {"type": str, "default": "small"}),
        ("--model-path", {"type": str, "default": "download/hubert_base_ls960.pt"}),
        (
            "--kmeans-model-path",
            {"type": str, "default": "download/hubert_base_ls960_L9_km500.bin"},
        ),
        (
            "--start",
            {
                "type": int,
                "default": 0,
                "help": "Process pieces starting from this number (inclusive).",
            },
        ),
        (
            "--stop",
            {
                "type": int,
                "default": -1,
                "help": "Stop processing pieces until this number (exclusive).",
            },
        ),
        ("--window-duration", {"type": float, "default": 300.0}),
        ("--shift-duration", {"type": float, "default": 250.0}),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        arg_parser.add_argument(flag, **spec)

    return arg_parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def extract_and_save_one_cuts(
    raw_cuts_path,
    cuts_path,
    model,
    apply_kmeans,
    do_normalize,
    window_duration,
    shift_duration,
):
    """Label every cut of a manifest with k-means ids and save the result.

    The audio of each cut is processed in windows of ``window_duration``
    seconds that advance by ``shift_duration`` seconds. For every window
    after the first one, the leading ``get_out_length(overlap)`` labels are
    dropped, because those frames were already produced by the previous
    window. The labels of a cut are stored as a space-separated string in
    ``custom["kmeans"]`` of a copy of the cut.

    Args:
      raw_cuts_path:
        Manifest to read the cuts from.
      cuts_path:
        Where the labelled manifest is written.
      model:
        Feature extractor; must provide ``parameters()`` and
        ``extract_features(source=, padding_mask=, mask=, output_layer=)``.
      apply_kmeans:
        Callable mapping the extracted features to cluster ids.
      do_normalize:
        If true, layer-normalize each window before feature extraction.
      window_duration:
        Window length in seconds.
      shift_duration:
        Window hop in seconds; must not exceed ``window_duration``.

    Raises:
      AssertionError: if ``window_duration < shift_duration`` or a cut is
        not sampled at 16 kHz.
      ValueError: if a cut needs more than one window but the hop is less
        than one sample.
    """
    logging.info(f"Loading {raw_cuts_path}")
    cut_set = CutSet.from_file(raw_cuts_path)

    logging.info("Extracting kmeans")
    cuts = []

    assert window_duration >= shift_duration
    window_size = int(window_duration * 16000)
    shift_size = int(shift_duration * 16000)
    overlap_size = window_size - shift_size
    # Number of output frames covered by the overlap between two windows.
    out_overlap_size = get_out_length(overlap_size)

    # The model does not move while we iterate, so look the device up once.
    device = next(model.parameters()).device

    for cut in tqdm(cut_set):
        assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"

        audio = cut.load_audio()

        T = audio.shape[1]  # windowing is done along axis 1
        start = 0
        kmeans = []
        while start < T:
            real_window_size = min(window_size, T - start)
            audio_window = audio[:, start : start + real_window_size]

            x = torch.from_numpy(audio_window).float().to(device)
            if do_normalize:
                x = torch.nn.functional.layer_norm(x, x.shape)

            feature, _ = model.extract_features(
                source=x,
                padding_mask=None,
                mask=False,
                output_layer=9,
            )
            feature = feature.squeeze(0)

            current_kmeans = apply_kmeans(feature).tolist()

            # The first window is kept whole; later ones skip the frames
            # that overlap with the previous window.
            skip = 0 if start == 0 else out_overlap_size
            kmeans.extend(current_kmeans[skip:])

            if T - start <= window_size:
                # This window reached the end of the audio.
                break

            if shift_size <= 0:
                # A non-positive hop would never reach the end of the audio
                # and the loop would spin forever.
                raise ValueError(
                    f"shift_duration={shift_duration} advances by {shift_size} "
                    "samples; it must cover at least one sample"
                )
            start += shift_size

        kmeans = " ".join(map(str, kmeans))

        cut_with_kmeans = fastcopy(
            cut,
            custom={"kmeans": kmeans},
        )
        cuts.append(cut_with_kmeans)

    cuts = CutSet(cuts)

    logging.info(f"Saving to {cuts_path}")
    cuts.to_file(cuts_path)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_kmeans(args):
    """Extract k-means labels for one Libri-Light subset.

    For ``small`` a single manifest in ``data/kmeans`` is processed. For
    ``medium`` and ``large`` the pieces ``args.start`` (inclusive) to
    ``args.stop`` (exclusive) found in ``data/kmeans/<subset>_split`` are
    processed. A negative ``args.stop`` (the command-line default is -1)
    means "up to and including the last raw piece present on disk".
    Outputs that already exist and inputs that are missing are skipped.

    Args:
      args:
        Namespace providing ``subset``, ``model_path``,
        ``kmeans_model_path``, ``start``, ``stop``, ``window_duration`` and
        ``shift_duration``.

    Raises:
      AssertionError: if the subset is unknown, the output directory does
        not exist, or the resolved ``stop`` is not larger than ``start``.
    """
    assert args.subset in ("small", "medium", "large"), f"{args.subset}"

    output_dir = (
        f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
    )
    output_dir = Path(output_dir)
    assert output_dir.exists(), f"{output_dir} does not exist!"

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"device: {device}")

    prefix = "librilight"

    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
    model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [args.model_path]
    )
    model = model[0].eval().to(device)
    do_normalize = task.cfg.normalize

    window_duration = args.window_duration
    shift_duration = args.shift_duration

    if args.subset == "small":
        cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            return

        raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz"
        if not raw_cuts_path.is_file():
            logging.info(f"{raw_cuts_path} does not exist - skipping it")
            return

        extract_and_save_one_cuts(
            raw_cuts_path,
            cuts_path,
            model,
            apply_kmeans,
            do_normalize,
            window_duration,
            shift_duration,
        )
    else:
        num_digits = 8  # num_digits is fixed by lhotse split-lazy
        start = args.start
        stop = args.stop
        if stop < 0:
            # With the defaults (--start 0 --stop -1) the assertion below
            # could never hold. Resolve a negative stop to "one past the
            # highest-numbered raw piece found in output_dir".
            piece_ids = [
                int(p.name.split(".")[-3])
                for p in output_dir.glob(f"{prefix}_cuts_{args.subset}_raw.*.jsonl.gz")
                if p.name.split(".")[-3].isdigit()
            ]
            if not piece_ids:
                logging.info(f"No raw pieces found in {output_dir} - skipping")
                return
            stop = max(piece_ids) + 1
        assert stop > start, "stop must be larger than start!"

        for i in range(start, stop):
            idx = f"{i}".zfill(num_digits)
            logging.info(f"Processing {idx}/{stop - 1}")

            cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
            if cuts_path.is_file():
                logging.info(f"{cuts_path} exists - skipping")
                continue

            raw_cuts_path = (
                output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz"
            )
            if not raw_cuts_path.is_file():
                logging.info(f"{raw_cuts_path} does not exist - skipping it")
                continue

            extract_and_save_one_cuts(
                raw_cuts_path,
                cuts_path,
                model,
                apply_kmeans,
                do_normalize,
                window_duration,
                shift_duration,
            )
|
||||||
|
|
||||||
|
|
||||||
|
def get_out_length(T):
    """Return how many frames a stack of strided convolutions yields.

    The stack is fixed to seven layers given as (kernel size, stride): one
    (10, 5) layer, four (3, 2) layers and two (2, 2) layers. Each layer maps
    a length ``n`` to ``floor((n - kernel) / stride) + 1``.

    Args:
      T:
        Length of the input to the first layer.

    Returns:
      The length after the last layer, or 0 if the input is too short to
      yield a single frame.
    """
    # Only kernel size and stride influence the length.
    kernel_stride_pairs = ((10, 5),) + ((3, 2),) * 4 + ((2, 2),) * 2

    frames = T
    for kernel, hop in kernel_stride_pairs:
        frames = math.floor((frames - kernel) / hop) + 1

    return max(0, frames)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Timestamped records that also name the emitting file and line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    args = get_args()
    # Record the effective options at the top of the log.
    logging.info(vars(args))
    extract_kmeans(args)
|
||||||
107
egs/librilight/SSL/local/preprocess_librilight.py
Executable file
107
egs/librilight/SSL/local/preprocess_librilight.py
Executable file
@ -0,0 +1,107 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet
|
||||||
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slows things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line options of the Libri-Light preprocessing script.

    Returns:
      The ``argparse.Namespace`` produced from ``sys.argv``. Its only
      attribute, ``dataset``, is ``None`` when the option is not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset",
        type=str,
        # The previous text talked about computing fbank, which this script
        # does not do; it builds raw cut manifests.
        help="""Space-separated dataset parts to build raw cuts for,
        e.g. "small medium". If not given, we will use all parts.""",
    )

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_librilight(
    dataset: Optional[str] = None,
):
    """Turn the cached Libri-Light manifests into raw cut manifests.

    For every requested part, the recordings and supervisions read from
    ``data/manifests`` are combined into a ``CutSet``, trimmed to the
    supervisions and written to
    ``data/kmeans/librilight_cuts_<part>_raw.jsonl.gz``. Parts whose output
    file already exists are skipped.

    Args:
      dataset:
        Whitespace-separated part names. ``None`` selects ``small``,
        ``medium`` and ``large``.

    Raises:
      AssertionError: if the cached manifests cannot be read or do not cover
        every requested part.
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/kmeans")

    if dataset is None:
        dataset_parts = (
            "small",
            "medium",
            "large",
        )
    else:
        # split() without arguments tolerates repeated, leading and trailing
        # whitespace; splitting on a single " " produced empty part names
        # that then failed the manifest-count check below.
        dataset_parts = dataset.split()

    prefix = "librilight"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    for partition, m in manifests.items():
        cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
        if cuts_path.is_file():
            logging.info(f"{partition} already exists - skipping.")
            continue

        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )
        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False, min_duration=None
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Timestamped records that also name the emitting file and line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    args = get_args()
    # Record the effective options at the top of the log.
    logging.info(vars(args))
    preprocess_librilight(dataset=args.dataset)
|
||||||
100
egs/librilight/SSL/prepare.sh
Executable file
100
egs/librilight/SSL/prepare.sh
Executable file
@ -0,0 +1,100 @@
|
|||||||
|
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=32
# Run stage 1 to stage 4 by default. There is no stage 0; stage=0 simply
# means "start from the first stage".
stage=0
stop_stage=4

# We assume dl_dir (download dir) contains the following
# directories. The corpus itself is NOT downloaded by this script.
#
#  - $dl_dir/LibriLight
#      - small
#      - medium
#      - large
#
# You can download them from
#  - https://dl.fbaipublicfiles.com/librilight/data/small.tar
#  - https://dl.fbaipublicfiles.com/librilight/data/medium.tar
#  - https://dl.fbaipublicfiles.com/librilight/data/large.tar

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "Running prepare.sh"

log "dl_dir: $dl_dir"

if [ "$stage" -le 1 ] && [ "$stop_stage" -ge 1 ]; then
  log "Stage 1: Prepare Libri-Light manifest"
  # We assume that you have downloaded the Libri-Light corpus
  # to $dl_dir/LibriLight
  mkdir -p data/manifests
  if [ ! -e data/manifests/.librilight.done ]; then
    # Quoted: dl_dir is derived from $PWD and may contain spaces.
    lhotse prepare librilight -j "$nj" "$dl_dir/LibriLight" data/manifests
    touch data/manifests/.librilight.done
  fi
fi

if [ "$stage" -le 2 ] && [ "$stop_stage" -ge 2 ]; then
  log "Stage 2: Preprocess Libri-Light manifest"
  mkdir -p data/kmeans
  if [ ! -f data/kmeans/.preprocess_complete ]; then
    python3 ./local/preprocess_librilight.py
    touch data/kmeans/.preprocess_complete
  fi
fi

if [ "$stage" -le 3 ] && [ "$stop_stage" -ge 3 ]; then
  log "Stage 3: Split medium and large subset into pieces"
  num_per_split=10000
  for subset in medium large; do
    split_dir=data/kmeans/${subset}_split
    if [ ! -f "$split_dir/.split_completed" ]; then
      lhotse split-lazy "./data/kmeans/librilight_cuts_${subset}_raw.jsonl.gz" "$split_dir" "$num_per_split"
      touch "$split_dir/.split_completed"
    fi
  done
fi

if [ "$stage" -le 4 ] && [ "$stop_stage" -ge 4 ]; then
  log "Stage 4: Extract SSL target for librilight"
  # The HuBERT checkpoint and its k-means model are fetched into ./download,
  # which is where local/extract_kmeans.py looks for them by default.
  for model_file in hubert_base_ls960.pt hubert_base_ls960_L9_km500.bin; do
    if [ ! -e "download/$model_file" ]; then
      wget "https://dl.fbaipublicfiles.com/hubert/$model_file" -P download
    fi
  done
  for subset in small medium large; do
    if [ ! -e "data/kmeans/.extract_${subset}.done" ]; then
      ./local/extract_kmeans.py --subset "$subset"
      touch "data/kmeans/.extract_${subset}.done"
    fi
  done
fi
|
||||||
116
egs/librilight/SSL/run_multi_node_multi_gpu.sh
Executable file
116
egs/librilight/SSL/run_multi_node_multi_gpu.sh
Executable file
@ -0,0 +1,116 @@
|
|||||||
|
#!/usr/bin/env bash
#
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang,
#                                       Yifan Yang)
#
# This script is the entry point to start model training
# with multi-node multi-GPU.
#
# Read the usage instructions below for how to run this script.

set -e

# DDP related parameters
master_addr=
node_rank=
num_nodes=4
master_port=12354

. shared/parse_options.sh

# Print the usage instructions and terminate the script with status 1.
function usage() {
  echo "Usage: "
  echo ""
  echo " $0 \\"
  echo " --master-addr <IP of master> \\"
  echo " --master-port <Port of master> \\"
  echo " --node-rank <rank of this node> \\"
  echo " --num-nodes <Number of node>"
  echo ""
  echo " --master-addr The ip address of the master node."
  echo " --master-port The port of the master node."
  echo " --node-rank Rank of this node."
  echo " --num-nodes Number of nodes in DDP training."
  echo ""
  echo "Usage example:"
  echo "Suppose you want to use DDP with two machines:"
  echo " (1) Machine 1 has 4 GPUs. You want to use"
  echo " GPU 0, 1, and 3 for training"
  echo " IP of machine 1 is: 10.177.41.71"
  echo " (2) Machine 2 has 4 GPUs. You want to use"
  echo " GPU 0, 2, and 3 for training"
  echo " IP of machine 2 is: 10.177.41.72"
  echo "You want to select machine 1 as the master node and"
  echo "assume that the port 1234 is free on machine 1."
  echo ""
  echo "On machine 1, you run:"
  echo ""
  echo " export CUDA_VISIBLE_DEVICES=\"0,1,3\""
  echo " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
  echo ""
  echo "On machine 2, you run:"
  echo ""
  echo " export CUDA_VISIBLE_DEVICES=\"0,2,3\""
  echo " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
  echo ""
  echo "Note 1:"
  echo " You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
  echo ""
  echo "Note 2:"
  echo " If you use torch < 1.9.0, then every node has to use the same number of GPUs for training."
  echo " If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."
  exit 1
}

default='\033[0m'
bold='\033[1m'
red='\033[31m'

# Print "$1" prefixed with a bold red [ERROR] tag.
function error() {
  # The message is passed as an argument, not as part of the format string,
  # so a '%' or a backslash in it is printed literally.
  printf "${bold}${red}[ERROR]${default} %s\n" "$1"
}

# Abort with the usage text unless "$1" is non-empty.
#   $1: the value to check
#   $2: what the user has to set, used in the error message
function require_set() {
  if [ -z "$1" ]; then
    echo
    error "Please set $2"
    echo
    # usage exits from the main shell here, not from a subshell.
    usage
  fi
}

require_set "${CUDA_VISIBLE_DEVICES:-}" "CUDA_VISIBLE_DEVICES"
require_set "$master_addr" "--master-addr"
require_set "$master_port" "--master-port"
require_set "$node_rank" "--node-rank"
require_set "$num_nodes" "--num-nodes"

# Number of GPUs this node has
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")

echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"

export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port

set -x

torchrun \
  --nproc_per_node "$num_gpus" \
  --nnodes "$num_nodes" \
  --node_rank "$node_rank" \
  --master_addr "$master_addr" \
  --master_port "$master_port" \
  zipformer/pretrain.py \
  --use-multi-node 1 \
  --master-port "$master_port" \
  --num-epochs 20 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_pretrain \
  --max-duration 350 \
  --quadratic-duration 1024 \
  --accum-grad 1 \
  --do-normalize 1 \
  --mask-prob 0.8 \
  --dropout-input 0.0 \
  --dropout-features 0.0 \
  --feature-grad-mult 1.0 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 768,1536,2048,3072,2048,1536 \
  --encoder-dim 256,512,768,1024,768,512 \
  --encoder-unmasked-dim 256,256,256,320,256,256 \
  --base-lr 0.045
|
||||||
1
egs/librilight/SSL/shared
Symbolic link
1
egs/librilight/SSL/shared
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../icefall/shared/
|
||||||
@ -1 +0,0 @@
|
|||||||
../../../librispeech/SSL/zipformer/asr_datamodule.py
|
|
||||||
File diff suppressed because it is too large
Load Diff
149
egs/librilight/SSL/zipformer/pretrain.py
Normal file → Executable file
149
egs/librilight/SSL/zipformer/pretrain.py
Normal file → Executable file
@ -1,11 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang,
|
# Wei Kang,
|
||||||
# Mingshuang Luo,
|
# Mingshuang Luo,
|
||||||
# Zengwei Yao,
|
# Zengwei Yao,
|
||||||
# Yifan Yang,
|
# Yifan Yang,
|
||||||
# Daniel Povey)
|
# Daniel Povey)
|
||||||
#
|
|
||||||
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
|
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
@ -21,22 +20,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|
||||||
|
|
||||||
# For hubert model pretraining:
|
|
||||||
./zipformer/pretrain.py \
|
|
||||||
--world-size 8 \
|
|
||||||
--num-epochs 400 \
|
|
||||||
--start-epoch 1 \
|
|
||||||
--use-fp16 1 \
|
|
||||||
--exp-dir zipformer/exp \
|
|
||||||
--max-duration 87.5 \
|
|
||||||
--accum-grad 4
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
@ -68,7 +51,13 @@ from icefall.checkpoint import (
|
|||||||
save_checkpoint_with_global_batch_idx,
|
save_checkpoint_with_global_batch_idx,
|
||||||
update_averaged_model,
|
update_averaged_model,
|
||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import (
|
||||||
|
cleanup_dist,
|
||||||
|
get_local_rank,
|
||||||
|
get_rank,
|
||||||
|
get_world_size,
|
||||||
|
setup_dist,
|
||||||
|
)
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
@ -398,19 +387,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
and the value should be the multiple of 4, for faster computation""",
|
and the value should be the multiple of 4, for faster computation""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--untie-final-proj",
|
|
||||||
type=bool,
|
|
||||||
default=False,
|
|
||||||
help="use separate projection for each target",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-multi-node",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True if using multi-node multi-GPU.
|
||||||
|
You are not supposed to set it directly.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--world-size",
|
"--world-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -481,17 +472,17 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-epochs",
|
"--lr-hours",
|
||||||
type=float,
|
type=float,
|
||||||
default=10.5,
|
default=10000,
|
||||||
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
help="""Number of hours that affects how rapidly the learning rate decreases.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--warmup-batches",
|
"--warmup-batches",
|
||||||
type=float,
|
type=float,
|
||||||
default=5000,
|
default=1000,
|
||||||
help="Eden warmup steps",
|
help="Eden warmup steps",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -541,7 +532,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-every-n",
|
"--save-every-n",
|
||||||
type=int,
|
type=int,
|
||||||
default=100000,
|
default=10000,
|
||||||
help="""Save checkpoint after processing this number of batches"
|
help="""Save checkpoint after processing this number of batches"
|
||||||
periodically. We save checkpoint to exp-dir/ whenever
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
@ -554,7 +545,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--keep-last-k",
|
"--keep-last-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=100000,
|
||||||
help="""Only keep this number of checkpoints on disk.
|
help="""Only keep this number of checkpoints on disk.
|
||||||
For instance, if it is 3, there are only 3 checkpoints
|
For instance, if it is 3, there are only 3 checkpoints
|
||||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||||
@ -578,7 +569,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accum-grad",
|
"--accum-grad",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=1,
|
||||||
help="""update gradient when batch_idx_train % accum_grad == 0.
|
help="""update gradient when batch_idx_train % accum_grad == 0.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -591,17 +582,24 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sample-size",
|
"--max-keep-size",
|
||||||
type=float,
|
type=int,
|
||||||
default=250000,
|
default=1024000,
|
||||||
help="max sample size",
|
help="exclude sample longer than this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-sample-size",
|
"--min-keep-size",
|
||||||
type=float,
|
type=float,
|
||||||
default=32000,
|
default=32000,
|
||||||
help="min sample size",
|
help="exclude sample longer less than this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sample-size",
|
||||||
|
type=float,
|
||||||
|
default=1024000,
|
||||||
|
help="max sample size to crop to for batching.",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
@ -953,6 +951,13 @@ def train_one_epoch(
|
|||||||
if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
scheduler.step_batch(params.batch_idx_train)
|
scheduler.step_batch(params.batch_idx_train)
|
||||||
|
# Use the number of hours of speech to adjust the learning rate
|
||||||
|
scheduler.step_epoch(
|
||||||
|
params.batch_idx_train
|
||||||
|
* params.max_duration
|
||||||
|
* params.world_size
|
||||||
|
/ 3600
|
||||||
|
)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
@ -960,10 +965,10 @@ def train_one_epoch(
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except: # noqa
|
except Exception as e: # noqa
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
display_and_save_batch(batch, params=params)
|
display_and_save_batch(batch, params=params)
|
||||||
raise
|
raise e
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
return
|
return
|
||||||
@ -1064,7 +1069,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
@ -1089,8 +1094,15 @@ def run(rank, world_size, args):
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
fix_random_seed(params.seed)
|
fix_random_seed(params.seed)
|
||||||
|
|
||||||
|
if params.use_multi_node:
|
||||||
|
local_rank = get_local_rank()
|
||||||
|
else:
|
||||||
|
local_rank = rank
|
||||||
|
logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")
|
||||||
|
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
setup_dist(rank, world_size, params.master_port)
|
setup_dist(rank, world_size, params.master_port, params.use_multi_node)
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
logging.info("Training started")
|
logging.info("Training started")
|
||||||
@ -1102,8 +1114,8 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", local_rank)
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -1126,7 +1138,7 @@ def run(rank, world_size, args):
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
|
||||||
|
|
||||||
optimizer = ScaledAdam(
|
optimizer = ScaledAdam(
|
||||||
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
|
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
|
||||||
@ -1137,7 +1149,7 @@ def run(rank, world_size, args):
|
|||||||
scheduler = Eden(
|
scheduler = Eden(
|
||||||
optimizer,
|
optimizer,
|
||||||
params.lr_batches,
|
params.lr_batches,
|
||||||
params.lr_epochs,
|
params.lr_hours,
|
||||||
params.warmup_batches,
|
params.warmup_batches,
|
||||||
params.warmup_start,
|
params.warmup_start,
|
||||||
)
|
)
|
||||||
@ -1165,7 +1177,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
librilight = LibriLightDataModule(args)
|
librilight = LibriLightDataModule(args)
|
||||||
|
|
||||||
train_cuts = librilight.train_all_shuf_cuts()
|
train_cuts = librilight.all_shuf_cuts()
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
@ -1177,11 +1189,11 @@ def run(rank, world_size, args):
|
|||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if (
|
if (
|
||||||
c.duration < params.min_sample_size / params.sample_rate
|
c.duration < params.min_keep_size / params.sample_rate
|
||||||
or c.duration > params.max_sample_size / params.sample_rate
|
or c.duration > params.max_keep_size / params.sample_rate
|
||||||
):
|
):
|
||||||
# logging.warning(
|
# logging.warning(
|
||||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
# )
|
# )
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1198,6 +1210,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
train_dl = librilight.train_dataloaders(
|
train_dl = librilight.train_dataloaders(
|
||||||
train_cuts,
|
train_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
@ -1205,6 +1218,8 @@ def run(rank, world_size, args):
|
|||||||
num_classes=params.num_classes,
|
num_classes=params.num_classes,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librilight.dev_clean_cuts()
|
valid_cuts = librilight.dev_clean_cuts()
|
||||||
@ -1213,12 +1228,15 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
valid_dl = librilight.valid_dataloaders(
|
valid_dl = librilight.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
|
max_sample_size=params.max_sample_size,
|
||||||
sample_rate=params.sample_rate,
|
sample_rate=params.sample_rate,
|
||||||
label_rate=params.label_rate,
|
label_rate=params.label_rate,
|
||||||
random_crop=params.random_crop,
|
random_crop=params.random_crop,
|
||||||
pad_audio=False,
|
pad_audio=False,
|
||||||
num_classes=params.num_classes,
|
num_classes=params.num_classes,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
@ -1235,7 +1253,6 @@ def run(rank, world_size, args):
|
|||||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||||
scheduler.step_epoch(epoch - 1)
|
|
||||||
fix_random_seed(params.seed + epoch - 1)
|
fix_random_seed(params.seed + epoch - 1)
|
||||||
train_dl.sampler.set_epoch(epoch - 1)
|
train_dl.sampler.set_epoch(epoch - 1)
|
||||||
|
|
||||||
@ -1339,7 +1356,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
f"(={crit_values[criterion]}) ..."
|
f"(={crit_values[criterion]}) ..."
|
||||||
)
|
)
|
||||||
display_and_save_batch(batch, params=params)
|
display_and_save_batch(batch, params=params)
|
||||||
raise
|
raise e
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
)
|
)
|
||||||
@ -1351,12 +1368,18 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
world_size = args.world_size
|
if args.use_multi_node:
|
||||||
assert world_size >= 1
|
rank = get_rank()
|
||||||
if world_size > 1:
|
world_size = get_world_size()
|
||||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
args.world_size = world_size
|
||||||
|
run(rank=rank, world_size=world_size, args=args)
|
||||||
else:
|
else:
|
||||||
run(rank=0, world_size=1, args=args)
|
world_size = args.world_size
|
||||||
|
assert world_size >= 1
|
||||||
|
if world_size > 1:
|
||||||
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||||
|
else:
|
||||||
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
# Copyright 2021 Piotr Żelasko
|
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
|
||||||
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
|
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -24,9 +23,10 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import lhotse
|
||||||
import torch
|
import torch
|
||||||
from dataset import HubertDataset
|
from dataset import HubertDataset
|
||||||
from lhotse import CutSet, combine, load_manifest_lazy
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
|
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -46,7 +46,7 @@ class LibriLightDataModule:
|
|||||||
"""
|
"""
|
||||||
DataModule for SSL experiments.
|
DataModule for SSL experiments.
|
||||||
It assumes there is always one train and valid dataloader,
|
It assumes there is always one train and valid dataloader,
|
||||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
but there can be multiple test dataloaders (e.g. LibriLight test-clean
|
||||||
and test-other).
|
and test-other).
|
||||||
|
|
||||||
It contains all the common data pipeline modules used in SSL
|
It contains all the common data pipeline modules used in SSL
|
||||||
@ -63,7 +63,7 @@ class LibriLightDataModule:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group(
|
group = parser.add_argument_group(
|
||||||
title="ASR SSL related options",
|
title="SSL data related options",
|
||||||
description="These options are used for the preparation of "
|
description="These options are used for the preparation of "
|
||||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
"effective batch sizes, sampling strategies.",
|
"effective batch sizes, sampling strategies.",
|
||||||
@ -92,10 +92,29 @@ class LibriLightDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--num-buckets",
|
"--num-buckets",
|
||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=1000,
|
||||||
help="The number of buckets for the DynamicBucketingSampler"
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
"(you might want to increase it for larger datasets).",
|
"(you might want to increase it for larger datasets).",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-cuts-for-bins-estimate",
|
||||||
|
type=int,
|
||||||
|
default=1000000,
|
||||||
|
help="We will draw this many cuts to estimate the duration"
|
||||||
|
"bins for creating similar-duration buckets. Larger number"
|
||||||
|
"means a better estimate to the data distribution, possibly"
|
||||||
|
"at a longer init cost.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--quadratic-duration",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="When set, it adds an extra penalty that's quadratic"
|
||||||
|
"in size w.r.t. a cuts duration. This helps get a more"
|
||||||
|
"even GPU utilization across different input lengths when"
|
||||||
|
"models have quadratic input complexity. Set between 15"
|
||||||
|
"and 40 for transformers.",
|
||||||
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--shuffle",
|
"--shuffle",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -112,7 +131,7 @@ class LibriLightDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=8,
|
||||||
help="The number of training dataloader workers that "
|
help="The number of training dataloader workers that "
|
||||||
"collect the batches.",
|
"collect the batches.",
|
||||||
)
|
)
|
||||||
@ -126,12 +145,13 @@ class LibriLightDataModule:
|
|||||||
"--random-crop",
|
"--random-crop",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="audio sample rate",
|
help="always crop from the beginning if false",
|
||||||
)
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
|
max_sample_size: Optional[int] = None,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
label_rate: float = 50,
|
label_rate: float = 50,
|
||||||
random_crop: bool = True,
|
random_crop: bool = True,
|
||||||
@ -139,6 +159,8 @@ class LibriLightDataModule:
|
|||||||
num_classes: list = [504],
|
num_classes: list = [504],
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -149,6 +171,7 @@ class LibriLightDataModule:
|
|||||||
"""
|
"""
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = HubertDataset(
|
train = HubertDataset(
|
||||||
|
max_sample_size=max_sample_size,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
label_rate=label_rate,
|
label_rate=label_rate,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
@ -162,9 +185,14 @@ class LibriLightDataModule:
|
|||||||
train_sampler = DynamicBucketingSampler(
|
train_sampler = DynamicBucketingSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
|
quadratic_duration=self.args.quadratic_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
|
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -172,6 +200,8 @@ class LibriLightDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -198,15 +228,19 @@ class LibriLightDataModule:
|
|||||||
def valid_dataloaders(
|
def valid_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_valid: CutSet,
|
cuts_valid: CutSet,
|
||||||
|
max_sample_size: Optional[int] = None,
|
||||||
sample_rate: float = 16000,
|
sample_rate: float = 16000,
|
||||||
label_rate: float = 50,
|
label_rate: float = 50,
|
||||||
random_crop: bool = True,
|
random_crop: bool = True,
|
||||||
pad_audio: bool = False,
|
pad_audio: bool = False,
|
||||||
num_classes: list = [504],
|
num_classes: list = [504],
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
validate = HubertDataset(
|
validate = HubertDataset(
|
||||||
|
max_sample_size=max_sample_size,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
label_rate=label_rate,
|
label_rate=label_rate,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
@ -217,7 +251,10 @@ class LibriLightDataModule:
|
|||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
|
quadratic_duration=self.args.quadratic_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
@ -230,81 +267,11 @@ class LibriLightDataModule:
|
|||||||
|
|
||||||
return valid_dl
|
return valid_dl
|
||||||
|
|
||||||
def test_dataloaders(
|
|
||||||
self,
|
|
||||||
cuts: CutSet,
|
|
||||||
sample_rate: float = 16000,
|
|
||||||
label_rate: float = 50,
|
|
||||||
random_crop: bool = True,
|
|
||||||
pad_audio: bool = False,
|
|
||||||
num_classes: list = [504],
|
|
||||||
do_normalize: bool = True,
|
|
||||||
) -> DataLoader:
|
|
||||||
logging.debug("About to create test dataset")
|
|
||||||
test = HubertDataset(
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
label_rate=label_rate,
|
|
||||||
random_crop=random_crop,
|
|
||||||
pad_audio=pad_audio,
|
|
||||||
num_classes=num_classes,
|
|
||||||
do_normalize=do_normalize,
|
|
||||||
)
|
|
||||||
sampler = DynamicBucketingSampler(
|
|
||||||
cuts,
|
|
||||||
max_duration=self.args.max_duration,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
logging.debug("About to create test dataloader")
|
|
||||||
test_dl = DataLoader(
|
|
||||||
test,
|
|
||||||
batch_size=None,
|
|
||||||
sampler=sampler,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
)
|
|
||||||
return test_dl
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def small_cuts(self) -> CutSet:
|
def all_shuf_cuts(self) -> CutSet:
|
||||||
logging.info("About to get small cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def medium_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get medium cuts")
|
|
||||||
filenames = glob.glob(
|
|
||||||
f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz"
|
|
||||||
)
|
|
||||||
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
|
|
||||||
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
|
||||||
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
|
||||||
sorted_filenames = [f[1] for f in idx_filenames]
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode"
|
"About to get the shuffled librilight small, medium and large cuts"
|
||||||
)
|
)
|
||||||
|
|
||||||
return combine(load_manifest_lazy(p) for p in sorted_filenames)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def large_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get large cuts")
|
|
||||||
filenames = glob.glob(
|
|
||||||
f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz"
|
|
||||||
)
|
|
||||||
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
|
|
||||||
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
|
||||||
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
|
||||||
sorted_filenames = [f[1] for f in idx_filenames]
|
|
||||||
logging.info(
|
|
||||||
f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
return combine(load_manifest_lazy(p) for p in sorted_filenames)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_all_shuf_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get the shuffled small, medium and large cuts")
|
|
||||||
small_cuts = self.small_cuts()
|
small_cuts = self.small_cuts()
|
||||||
medium_cuts = self.medium_cuts()
|
medium_cuts = self.medium_cuts()
|
||||||
large_cuts = self.large_cuts()
|
large_cuts = self.large_cuts()
|
||||||
@ -313,22 +280,60 @@ class LibriLightDataModule:
|
|||||||
medium_cuts,
|
medium_cuts,
|
||||||
large_cuts,
|
large_cuts,
|
||||||
weights=[
|
weights=[
|
||||||
122867, # len(small_cuts)
|
229051, # len(small_cuts)
|
||||||
1104071, # len(medium_cuts)
|
2022949, # len(medium_cuts)
|
||||||
11012085, # len(large_cuts)
|
19883414, # len(large_cuts)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_clean_cuts(self) -> CutSet:
|
def dev_clean_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev-clean cuts")
|
logging.info("About to get librispeech dev-clean cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_other_cuts(self) -> CutSet:
|
def small_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev-other cuts")
|
logging.info("About to get librilight small cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def medium_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get librilight medium cuts")
|
||||||
|
filenames = glob.glob(
|
||||||
|
str(
|
||||||
|
self.args.manifest_dir
|
||||||
|
/ "medium_split"
|
||||||
|
/ "librilight_cuts_medium.*.jsonl.gz"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
|
||||||
|
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
||||||
|
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||||
|
sorted_filenames = [f[1] for f in idx_filenames]
|
||||||
|
logging.info(
|
||||||
|
f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode"
|
||||||
|
)
|
||||||
|
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def large_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get librilight large cuts")
|
||||||
|
filenames = glob.glob(
|
||||||
|
str(
|
||||||
|
self.args.manifest_dir
|
||||||
|
/ "large_split"
|
||||||
|
/ "librilight_cuts_large.*.jsonl.gz"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
|
||||||
|
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
||||||
|
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||||
|
sorted_filenames = [f[1] for f in idx_filenames]
|
||||||
|
logging.info(
|
||||||
|
f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode"
|
||||||
|
)
|
||||||
|
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
|
||||||
|
|||||||
@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
do_normalize: bool,
|
do_normalize: bool,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -150,7 +152,10 @@ class LibriSpeechAsrDataModule:
|
|||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -158,6 +163,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -181,13 +188,21 @@ class LibriSpeechAsrDataModule:
|
|||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
|
def valid_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_valid: CutSet,
|
||||||
|
do_normalize: bool,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
|
) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
validate = HubertAsrDataset(do_normalize=do_normalize)
|
validate = HubertAsrDataset(do_normalize=do_normalize)
|
||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
|
|||||||
0
egs/librispeech/SSL/hubert/decode.py
Normal file → Executable file
0
egs/librispeech/SSL/hubert/decode.py
Normal file → Executable file
0
egs/librispeech/SSL/hubert/decode_ce.py
Normal file → Executable file
0
egs/librispeech/SSL/hubert/decode_ce.py
Normal file → Executable file
4
egs/librispeech/SSL/hubert/finetune.py
Normal file → Executable file
4
egs/librispeech/SSL/hubert/finetune.py
Normal file → Executable file
@ -1091,6 +1091,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts,
|
train_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1099,6 +1101,8 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
|
|||||||
4
egs/librispeech/SSL/hubert/finetune_ce.py
Normal file → Executable file
4
egs/librispeech/SSL/hubert/finetune_ce.py
Normal file → Executable file
@ -1091,6 +1091,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts,
|
train_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1099,6 +1101,8 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
|
|||||||
0
egs/librispeech/SSL/hubert/pretrain.py
Normal file → Executable file
0
egs/librispeech/SSL/hubert/pretrain.py
Normal file → Executable file
0
egs/librispeech/SSL/hubert/pretrain_ce.py
Normal file → Executable file
0
egs/librispeech/SSL/hubert/pretrain_ce.py
Normal file → Executable file
@ -144,6 +144,8 @@ class LibriSpeechDataModule:
|
|||||||
num_classes: list = [504],
|
num_classes: list = [504],
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -170,7 +172,10 @@ class LibriSpeechDataModule:
|
|||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -178,6 +183,8 @@ class LibriSpeechDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -211,6 +218,8 @@ class LibriSpeechDataModule:
|
|||||||
pad_audio: bool = False,
|
pad_audio: bool = False,
|
||||||
num_classes: list = [504],
|
num_classes: list = [504],
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
validate = HubertDataset(
|
validate = HubertDataset(
|
||||||
@ -226,6 +235,8 @@ class LibriSpeechDataModule:
|
|||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
|
|||||||
0
egs/librispeech/SSL/zipformer/decode.py
Normal file → Executable file
0
egs/librispeech/SSL/zipformer/decode.py
Normal file → Executable file
4
egs/librispeech/SSL/zipformer/finetune.py
Normal file → Executable file
4
egs/librispeech/SSL/zipformer/finetune.py
Normal file → Executable file
@ -1388,6 +1388,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts,
|
train_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1396,6 +1398,8 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
|
|||||||
@ -296,7 +296,6 @@ class HubertModel(nn.Module):
|
|||||||
|
|
||||||
self.layer_norm = LayerNorm(self.embed)
|
self.layer_norm = LayerNorm(self.embed)
|
||||||
|
|
||||||
self.untie_final_proj = cfg.untie_final_proj
|
|
||||||
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
|
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
|
||||||
|
|
||||||
# modules below are not needed during fine-tuning
|
# modules below are not needed during fine-tuning
|
||||||
|
|||||||
@ -154,9 +154,9 @@ class AsrModel(nn.Module):
|
|||||||
|
|
||||||
ctc_loss = torch.nn.functional.ctc_loss(
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||||
targets=targets,
|
targets=targets.cpu(),
|
||||||
input_lengths=encoder_out_lens,
|
input_lengths=encoder_out_lens.cpu(),
|
||||||
target_lengths=target_lengths,
|
target_lengths=target_lengths.cpu(),
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
return ctc_loss
|
return ctc_loss
|
||||||
|
|||||||
7
egs/librispeech/SSL/zipformer/pretrain.py
Normal file → Executable file
7
egs/librispeech/SSL/zipformer/pretrain.py
Normal file → Executable file
@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -595,7 +594,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-keep-size",
|
"--max-keep-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=sys.maxsize,
|
default=320000,
|
||||||
help="exclude sample longer than this.",
|
help="exclude sample longer than this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1219,6 +1218,8 @@ def run(rank, world_size, args):
|
|||||||
num_classes=params.num_classes,
|
num_classes=params.num_classes,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1234,6 +1235,8 @@ def run(rank, world_size, args):
|
|||||||
pad_audio=False,
|
pad_audio=False,
|
||||||
num_classes=params.num_classes,
|
num_classes=params.num_classes,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
|
|||||||
1
egs/librispeech/SSL/zipformer_ctc/asr_datamodule.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/asr_datamodule.py
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/beam_search.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/beam_search.py
|
||||||
823
egs/librispeech/SSL/zipformer_ctc/ctc_decode.py
Executable file
823
egs/librispeech/SSL/zipformer_ctc/ctc_decode.py
Executable file
@ -0,0 +1,823 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Liyong Guo,
|
||||||
|
# Quandong Wang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) ctc-decoding
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method ctc-decoding
|
||||||
|
|
||||||
|
(2) 1best
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--hlg-scale 0.6 \
|
||||||
|
--decoding-method 1best
|
||||||
|
|
||||||
|
(3) nbest
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--hlg-scale 0.6 \
|
||||||
|
--decoding-method nbest
|
||||||
|
|
||||||
|
(4) nbest-rescoring
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--hlg-scale 0.6 \
|
||||||
|
--nbest-scale 1.0 \
|
||||||
|
--lm-dir data/lm \
|
||||||
|
--decoding-method nbest-rescoring
|
||||||
|
|
||||||
|
(5) whole-lattice-rescoring
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--hlg-scale 0.6 \
|
||||||
|
--nbest-scale 1.0 \
|
||||||
|
--lm-dir data/lm \
|
||||||
|
--decoding-method whole-lattice-rescoring
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from finetune_ctc import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
nbest_decoding,
|
||||||
|
nbest_oracle,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_n_best_list,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
get_texts,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Natural logarithm of 1e-10 (roughly -23.03), i.e. the log of a very small
# probability.
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for CTC decoding.

    The parser holds the checkpoint-selection options (--epoch, --iter,
    --avg, --use-averaged-model), the directories to read from, and the
    decoding options. Model-related options are appended at the end by
    `add_model_arguments`.

    Returns:
      The configured `argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        # The pieces are joined by implicit concatenation, so every piece
        # that is followed by another one must end with a space.
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""The lang dir
        It contains language related input files such as
        "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc-decoding",
        help="""Decoding method.
        Supported values are:
        - (1) ctc-decoding. Use CTC decoding with the CTC topology only.
          Token IDs are converted back to text with the token table
          found in --lang-dir. It needs neither an HLG decoding graph
          nor an n-gram LM.
        - (2) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (3) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
          the highest score is the decoding result.
        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
          is the decoding result.
        - (6) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when --decoding-method is one of the following values:
        nbest, nbest-rescoring, and nbest-oracle
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when --decoding-method is one of the following values:
        nbest, nbest-rescoring, and nbest-oracle
        A smaller value results in more unique paths.
        """,
    )

    parser.add_argument(
        "--hlg-scale",
        type=float,
        default=0.6,
        help="""The scale to be applied to `hlg.scores`.
        """,
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
        """,
    )

    add_model_arguments(parser)

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_decoding_params() -> AttributeDict:
    """Return the fixed decoding settings as an `AttributeDict`.

    The returned mapping holds the frame shift in milliseconds, the search
    and output beams, the lower and upper bounds on active states, and
    whether double-precision scores are used.
    """
    return AttributeDict(
        dict(
            frame_shift_ms=10,
            search_beam=20,
            output_beam=8,
            min_active_states=30,
            max_active_states=10000,
            use_double_scores=True,
        )
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch and return the hypotheses keyed by setting.

    The returned dict maps a string describing the decoding setting to the
    hypotheses of the batch; `value[i]` is the list of words decoded for
    the i-th utterance. The keys produced here are:

      - `ctc-decoding` for params.decoding_method == "ctc-decoding";
      - `oracle_<num_paths>_nbest_scale_<nbest_scale>` for "nbest-oracle";
      - `no_rescore` for "1best";
      - `no_rescore-nbest-scale-<nbest_scale>-<num_paths>` for "nbest";
      - whatever keys the rescoring helper returns for "nbest-rescoring"
        and "whole-lattice-rescoring" (one entry per LM scale tried).

    Args:
      params:
        Decoding configuration. This function reads `decoding_method`,
        `search_beam`, `output_beam`, `min_active_states`,
        `max_active_states`, `use_double_scores`, `num_paths` and
        `nbest_scale` from it.
      model:
        The neural model. Its `forward_encoder` and `ctc_output` methods
        are called.
      HLG:
        The decoding graph. Exactly one of `HLG` and `H` must be given.
      H:
        The CTC topology. Exactly one of `HLG` and `H` must be given.
      lexicon:
        Its `token_table` maps token IDs back to symbols for
        "ctc-decoding".
      graph_compiler:
        Not used by this function.
      batch:
        A batch from the dataloader. `batch["audio"]` and
        `batch["padding_mask"]` are fed to the model;
        `batch["supervisions"]["text"]` is read for "nbest-oracle".
      word_table:
        The word symbol table, used to map word IDs to words.
      G:
        The LM handed to the rescoring helpers. Only used when
        params.decoding_method is "nbest-rescoring" or
        "whole-lattice-rescoring".
    Returns:
      The dict described above, or None if the rescoring helper returned
      None.
    """

    def _ids_to_words(path) -> List[List[str]]:
        # Turn the best path(s) into one list of words per utterance.
        return [[word_table[i] for i in ids] for ids in get_texts(path)]

    # The graphs already live on the target device; follow whichever is given.
    device = HLG.device if HLG is not None else H.device

    audio = batch["audio"].to(device)
    padding_mask = batch["padding_mask"].to(device)
    encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
    ctc_output = model.ctc_output(encoder_out)

    num_frames = encoder_out_lens.cpu()

    # One row per utterance: (utterance index, start frame = 0, num frames).
    utt_indexes = torch.arange(audio.shape[0], dtype=torch.int32)
    start_frames = torch.zeros_like(num_frames, dtype=torch.int32)
    supervision_segments = torch.stack(
        (utt_indexes, start_frames, num_frames),
        1,
    ).to(torch.int32)

    if H is None:
        assert HLG is not None
        decoding_graph = HLG
    else:
        assert HLG is None
        decoding_graph = H

    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

    method = params.decoding_method

    if method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # `best_path.aux_labels` holds token IDs (not word IDs) because the
        # lattice was built from H rather than HLG, so `get_texts` yields a
        # list of lists of token IDs here.
        token_ids = get_texts(best_path)

        # Join the symbols of each utterance, turn "|" into a space and
        # split on whitespace to obtain the words.
        hyps = [
            "".join(
                lexicon.token_table[idx].replace("|", " ") for idx in utt_tokens
            ).split()
            for utt_tokens in token_ids
        ]
        return {"ctc-decoding": hyps}

    if method == "nbest-oracle":
        # A rescored lattice could be passed as well; the HLG-decoded
        # lattice is used because it is faster to obtain and its oracle WER
        # is only slightly worse.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=batch["supervisions"]["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: _ids_to_words(best_path)}

    if method == "1best":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        return {"no_rescore": _ids_to_words(best_path)}

    if method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: _ids_to_words(best_path)}

    assert method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]

    # LM scales tried during rescoring; each one yields its own result key.
    lm_scale_list = [
        0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3,
        1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0,
    ]

    if method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    else:
        assert False, f"Unsupported decoding method: {params.decoding_method}"

    if best_path_dict is None:
        return None

    return {
        lm_scale_str: _ids_to_words(best_path)
        for lm_scale_str, best_path in best_path_dict.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode every batch of a dataloader and collect the results.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        Decoding configuration, forwarded to :func:`decode_one_batch`.
      model:
        The neural model.
      HLG:
        The decoding graph, forwarded to :func:`decode_one_batch`.
      H:
        The CTC topology, forwarded to :func:`decode_one_batch`.
      lexicon:
        Forwarded to :func:`decode_one_batch`.
      graph_compiler:
        Forwarded to :func:`decode_one_batch`.
      word_table:
        It is the word symbol table.
      G:
        An LM, forwarded to :func:`decode_one_batch`.
    Returns:
      Return a dict whose keys are the setting names produced by
      :func:`decode_one_batch` (e.g. "ctc-decoding" or "no_rescore").
      Each value is a list of tuples with three elements:
      (cut ID, reference words, hypothesis words).
      Utterances of a batch for which :func:`decode_one_batch` returns
      None are recorded with an empty hypothesis under every key.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    # Utterances from batches that decoded to nothing before any result key
    # was known; they are attached to the keys of the first decoded batch.
    pending_empty = []

    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["cuts"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        if hyps_dict is None:
            # decode_one_batch signals "decoded to nothing" with None. Count
            # these utterances with an empty hypothesis instead of crashing
            # on `None.items()`.
            logging.warning(
                f"batch {batch_idx} decoded to nothing; "
                "using empty hypotheses for its cuts"
            )
            empty_batch = [
                (cut_id, ref_text.split(), [])
                for cut_id, ref_text in zip(cut_ids, texts)
            ]
            if results:
                for name in results:
                    results[name].extend(empty_batch)
            else:
                pending_empty.extend(empty_batch)
        else:
            if pending_empty and hyps_dict:
                for name in hyps_dict:
                    results[name].extend(pending_empty)
                pending_empty = []

            for name, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[name].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    if pending_empty:
        # No batch produced any result key, so there is nowhere to attach
        # these utterances.
        logging.warning(
            f"{len(pending_empty)} cuts decoded to nothing and no batch "
            "produced a result; they are not included in the results"
        )

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every key of `results_dict`, a `recogs-*` transcript file and an
    `errs-*` error-statistics file are written to `params.res_dir`. The key
    is part of both file names so that results for different settings
    (e.g. several LM scales) do not overwrite each other. A single
    `wer-summary-*` file then lists the WER of every key, best first, and
    the same summary is logged.

    Args:
      params:
        Provides `res_dir` (output directory) and `suffix` (file name
        suffix).
      test_set_name:
        Name of the decoded test set; used in file names and log messages.
      results_dict:
        Maps a setting name to a list of
        (cut ID, reference words, hypothesis words) tuples.
    """
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Lowest WER first, so the first entry is the best setting.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (best) line carries the "best for ..." note.
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main():
    """Decode the LibriSpeech dev-clean and dev-other sets and save the results.

    The function
      1. parses the command line and merges it with ``get_params()`` and
         ``get_decoding_params()``;
      2. builds the decoding graph: a CTC topology ``H`` for "ctc-decoding",
         otherwise ``HLG`` loaded from ``<lang_dir>/HLG.pt``;
      3. for "nbest-rescoring" and "whole-lattice-rescoring", loads the
         4-gram ``G`` from ``<lm_dir>/G_4_gram.pt``, building and caching it
         from ``G_4_gram.fst.txt`` if the ``.pt`` file does not exist;
      4. creates the model and loads (averaged) checkpoint weights according
         to ``--iter``/``--epoch``/``--avg``/``--use-averaged-model``;
      5. decodes each test set with ``decode_dataset`` and writes the output
         with ``save_results``.

    Raises:
      ValueError: if ``--iter`` is used and no, or not enough, checkpoints
        are found for the requested ``--avg``.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    assert params.decoding_method in (
        "ctc-decoding",
        "1best",
        "nbest",
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "nbest-oracle",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)

    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    params.blank_id = lexicon.token_table["<blk>"]
    # Token IDs start at 0, so the vocabulary size is the largest ID plus one.
    params.vocab_size = max_token_id + 1

    if params.decoding_method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
    else:
        H = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        HLG.scores *= params.hlg_scale
        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.decoding_method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ):
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            # Only the read needs the file handle; close it before the
            # (slow) FSA construction and the save below.
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                G_text = f.read()

            first_word_disambig_id = lexicon.word_table["#0"]

            G = k2.Fsa.from_openfst(G_text, acceptor=False)
            # G.aux_labels is not needed in later computations, so
            # remove it here.
            del G.aux_labels
            # CAUTION: The following line is crucial.
            # Arcs entering the back-off state have label equal to #0.
            # We have to change it to 0 here.
            G.labels[G.labels >= first_word_disambig_id] = 0
            # See https://github.com/k2-fsa/k2/issues/874
            # for why we need to set G.properties to None
            G.__dict__["_properties"] = None
            G = k2.Fsa.from_fsas([G]).to(device)
            G = k2.arc_sort(G)
            # Save a dummy value so that it can be loaded in C++.
            # See https://github.com/pytorch/pytorch/issues/67902
            # for why we need to do this.
            G.dummy = 1

            torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
            G = k2.Fsa.from_dict(d)

        if params.decoding_method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one is the excluded
            # start point of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    dev_clean_cuts = librispeech.dev_clean_cuts()
    dev_other_cuts = librispeech.dev_other_cuts()

    dev_clean_dl = librispeech.test_dataloaders(
        dev_clean_cuts,
        do_normalize=params.do_normalize,
    )
    dev_other_dl = librispeech.test_dataloaders(
        dev_other_cuts,
        do_normalize=params.do_normalize,
    )

    test_sets = ["dev-clean", "dev-other"]
    test_dls = [dev_clean_dl, dev_other_dl]

    # test_clean_cuts = librispeech.test_clean_cuts()
    # test_other_cuts = librispeech.test_other_cuts()

    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    # test_sets = ["test-clean", "test-other"]
    # test_dls = [test_clean_dl, test_other_dl]

    # Use a distinct loop variable so the list of dataloaders is not rebound
    # while it is being iterated over.
    for test_set, dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            word_table=lexicon.word_table,
            G=G,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the decoding entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/dataset.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/dataset.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/dataset.py
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/encoder_interface.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/encoder_interface.py
|
||||||
262
egs/librilight/SSL/zipformer/finetune.py → egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py
Normal file → Executable file
262
egs/librilight/SSL/zipformer/finetune.py → egs/librispeech/SSL/zipformer_ctc/finetune_ctc.py
Normal file → Executable file
@ -1,12 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang,
|
# Wei Kang,
|
||||||
# Mingshuang Luo,
|
# Mingshuang Luo,
|
||||||
# Zengwei Yao,
|
# Zengwei Yao,
|
||||||
# Yifan Yang,
|
# Yifan Yang,
|
||||||
# Daniel Povey)
|
# Daniel Povey)
|
||||||
#
|
|
||||||
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
|
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -27,19 +25,14 @@ Usage:
|
|||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
|
||||||
# For HuBERT model finetuning:
|
# For HuBERT model finetuning:
|
||||||
./hubert/finetune.py \
|
./zipformer_ctc/finetune_ctc.py \
|
||||||
--world-size 8 \
|
--world-size 8 \
|
||||||
--num-epochs 200 \
|
--num-epochs 222 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir hubert/exp \
|
--exp-dir zipformer_ctc/exp \
|
||||||
--full-libri 0 \
|
--full-libri 0 \
|
||||||
--max-duration 1000
|
--max-duration 600
|
||||||
|
|
||||||
It supports finetuning with:
|
|
||||||
- transducer loss (default), with `--use-transducer True --use-ctc False`
|
|
||||||
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
|
|
||||||
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -58,9 +51,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from decoder import Decoder
|
|
||||||
from hubert_ce import HubertModel
|
from hubert_ce import HubertModel
|
||||||
from joiner import Joiner
|
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
@ -72,6 +63,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -81,6 +73,7 @@ from icefall.checkpoint import (
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
@ -415,37 +408,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="use separate projection for each target",
|
help="use separate projection for each target",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder-dim",
|
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="Embedding dimension in the decoder model.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-dim",
|
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="""Dimension used in the joiner model.
|
|
||||||
Outputs from the encoder and decoder model are projected
|
|
||||||
to this dimension before adding.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-transducer",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="If True, use Transducer head.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-ctc",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="If True, use CTC head.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -509,6 +471,16 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="""The lang dir
|
||||||
|
It contains language related input files such as
|
||||||
|
"lexicon.txt"
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pretrained-dir",
|
"--pretrained-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -516,13 +488,6 @@ def get_parser():
|
|||||||
It specifies the directory where the pretrained checkpoint is saved.""",
|
It specifies the directory where the pretrained checkpoint is saved.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--bpe-model",
|
|
||||||
type=str,
|
|
||||||
default="data/lang_bpe_500/bpe.model",
|
|
||||||
help="Path to the BPE model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-lr", type=float, default=0.001, help="The base learning rate."
|
"--base-lr", type=float, default=0.001, help="The base learning rate."
|
||||||
)
|
)
|
||||||
@ -551,53 +516,6 @@ def get_parser():
|
|||||||
"schedules inside the model",
|
"schedules inside the model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--prune-range",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
|
||||||
"we are using to compute the loss",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lm-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.25,
|
|
||||||
help="The scale to smooth the loss with lm "
|
|
||||||
"(output of prediction network) part.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--am-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--simple-loss-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="To get pruning ranges, we will calculate a simple version"
|
|
||||||
"loss(joiner is just addition), this simple loss also uses for"
|
|
||||||
"training (as a regularization item). We will scale the simple loss"
|
|
||||||
"with this parameter before adding to the final loss.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--ctc-loss-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.2,
|
|
||||||
help="Scale for CTC loss.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
@ -720,8 +638,6 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||||
|
|
||||||
- warm_step: The warmup period that dictates the decay of the
|
|
||||||
scale on "simple" (un-pruned) loss.
|
|
||||||
"""
|
"""
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
@ -734,8 +650,6 @@ def get_params() -> AttributeDict:
|
|||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000, # For the 100h subset, use 800
|
"valid_interval": 3000, # For the 100h subset, use 800
|
||||||
# parameters for pruned RNN-T loss
|
|
||||||
"warm_step": 2000,
|
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -758,51 +672,12 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
decoder_dim=params.decoder_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
||||||
joiner = Joiner(
|
|
||||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
|
||||||
decoder_dim=params.decoder_dim,
|
|
||||||
joiner_dim=params.joiner_dim,
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(params: AttributeDict) -> nn.Module:
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
assert params.use_transducer or params.use_ctc, (
|
|
||||||
f"At least one of them should be True, "
|
|
||||||
f"but got params.use_transducer={params.use_transducer}, "
|
|
||||||
f"params.use_ctc={params.use_ctc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
|
|
||||||
if params.use_transducer:
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
else:
|
|
||||||
decoder = None
|
|
||||||
joiner = None
|
|
||||||
|
|
||||||
model = AsrModel(
|
model = AsrModel(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||||
decoder_dim=params.decoder_dim,
|
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
use_transducer=params.use_transducer,
|
|
||||||
use_ctc=params.use_ctc,
|
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -926,7 +801,7 @@ def save_checkpoint(
|
|||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
@ -953,44 +828,19 @@ def compute_loss(
|
|||||||
padding_mask = batch["padding_mask"].to(device)
|
padding_mask = batch["padding_mask"].to(device)
|
||||||
|
|
||||||
batch_idx_train = params.batch_idx_train
|
batch_idx_train = params.batch_idx_train
|
||||||
warm_step = params.warm_step
|
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
y = sp.encode(texts, out_type=int)
|
y = graph_compiler.texts_to_ids(texts, sep="|")
|
||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss, num_frames = model(
|
ctc_loss, num_frames = model(
|
||||||
x=audio,
|
x=audio,
|
||||||
padding_mask=padding_mask,
|
padding_mask=padding_mask,
|
||||||
y=y,
|
y=y,
|
||||||
prune_range=params.prune_range,
|
|
||||||
am_scale=params.am_scale,
|
|
||||||
lm_scale=params.lm_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = 0.0
|
assert ctc_loss.requires_grad == is_training
|
||||||
|
|
||||||
if params.use_transducer:
|
|
||||||
s = params.simple_loss_scale
|
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
|
||||||
# to params.simple_loss scale by warm_step.
|
|
||||||
simple_loss_scale = (
|
|
||||||
s
|
|
||||||
if batch_idx_train >= warm_step
|
|
||||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
|
||||||
)
|
|
||||||
pruned_loss_scale = (
|
|
||||||
1.0
|
|
||||||
if batch_idx_train >= warm_step
|
|
||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
|
||||||
)
|
|
||||||
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
|
||||||
|
|
||||||
if params.use_ctc:
|
|
||||||
loss += params.ctc_loss_scale * ctc_loss
|
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
|
||||||
|
|
||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@ -998,20 +848,15 @@ def compute_loss(
|
|||||||
info["frames"] = num_frames.sum().item()
|
info["frames"] = num_frames.sum().item()
|
||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||||
if params.use_transducer:
|
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
|
||||||
if params.use_ctc:
|
|
||||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
|
||||||
|
|
||||||
return loss, info
|
return ctc_loss, info
|
||||||
|
|
||||||
|
|
||||||
def compute_validation_loss(
|
def compute_validation_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
) -> MetricsTracker:
|
) -> MetricsTracker:
|
||||||
@ -1024,7 +869,7 @@ def compute_validation_loss(
|
|||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
graph_compiler=graph_compiler,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
)
|
)
|
||||||
@ -1047,7 +892,7 @@ def train_one_epoch(
|
|||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: LRSchedulerType,
|
scheduler: LRSchedulerType,
|
||||||
sp: spm.SentencePieceProcessor,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
scaler: GradScaler,
|
scaler: GradScaler,
|
||||||
@ -1120,7 +965,7 @@ def train_one_epoch(
|
|||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
graph_compiler=graph_compiler,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
)
|
)
|
||||||
@ -1143,7 +988,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
except: # noqa
|
except: # noqa
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
@ -1232,7 +1077,7 @@ def train_one_epoch(
|
|||||||
valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
graph_compiler=graph_compiler,
|
||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
@ -1246,7 +1091,7 @@ def train_one_epoch(
|
|||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.accum_grad != params.accum_grad - 1:
|
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
@ -1287,15 +1132,14 @@ def run(rank, world_size, args):
|
|||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
lexicon = Lexicon(params.lang_dir)
|
||||||
sp.load(params.bpe_model)
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
if not params.use_transducer:
|
|
||||||
params.ctc_loss_scale = 1.0
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -1388,6 +1232,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts,
|
train_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1396,6 +1242,8 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
@ -1403,7 +1251,7 @@ def run(rank, world_size, args):
|
|||||||
model=model,
|
model=model,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
sp=sp,
|
graph_compiler=graph_compiler,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1428,7 +1276,7 @@ def run(rank, world_size, args):
|
|||||||
model_avg=model_avg,
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sp=sp,
|
graph_compiler=graph_compiler,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
@ -1462,7 +1310,7 @@ def run(rank, world_size, args):
|
|||||||
def display_and_save_batch(
|
def display_and_save_batch(
|
||||||
batch: dict,
|
batch: dict,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
sp: spm.SentencePieceProcessor,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Display the batch statistics and save the batch into disk.
|
"""Display the batch statistics and save the batch into disk.
|
||||||
|
|
||||||
@ -1472,8 +1320,6 @@ def display_and_save_batch(
|
|||||||
for the content in it.
|
for the content in it.
|
||||||
params:
|
params:
|
||||||
Parameters for training. See :func:`get_params`.
|
Parameters for training. See :func:`get_params`.
|
||||||
sp:
|
|
||||||
The BPE model.
|
|
||||||
"""
|
"""
|
||||||
from lhotse.utils import uuid4
|
from lhotse.utils import uuid4
|
||||||
|
|
||||||
@ -1484,7 +1330,7 @@ def display_and_save_batch(
|
|||||||
audio = batch["audio"]
|
audio = batch["audio"]
|
||||||
logging.info(f"audio shape: {audio.shape}")
|
logging.info(f"audio shape: {audio.shape}")
|
||||||
|
|
||||||
y = sp.encode(batch["supervisions"]["text"], out_type=int)
|
y = graph_compiler.texts_to_ids(batch["supervisions"]["text"], sep="|")
|
||||||
num_tokens = sum(len(i) for i in y)
|
num_tokens = sum(len(i) for i in y)
|
||||||
logging.info(f"num tokens: {num_tokens}")
|
logging.info(f"num tokens: {num_tokens}")
|
||||||
|
|
||||||
@ -1493,7 +1339,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
sp: spm.SentencePieceProcessor,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
):
|
):
|
||||||
from lhotse.dataset import find_pessimistic_batches
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
@ -1509,7 +1355,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
loss, _ = compute_loss(
|
loss, _ = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
graph_compiler=graph_compiler,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
)
|
)
|
||||||
@ -1524,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
f"Failing criterion: {criterion} "
|
f"Failing criterion: {criterion} "
|
||||||
f"(={crit_values[criterion]}) ..."
|
f"(={crit_values[criterion]}) ..."
|
||||||
)
|
)
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
|
||||||
raise
|
raise
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/hubert_ce.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/hubert_ce.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/hubert_ce.py
|
||||||
153
egs/librispeech/SSL/zipformer_ctc/model.py
Normal file
153
egs/librispeech/SSL/zipformer_ctc/model.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|
||||||
|
class AsrModel(nn.Module):
    """Encoder followed by a CTC output head, trained with the CTC loss."""

    def __init__(
        self,
        encoder,
        encoder_dim: int = 768,
        vocab_size: int = 500,
    ):
        """CTC ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)

        Args:
          encoder:
            It is the transcription network in the paper. It must provide
            ``extract_features(source=..., padding_mask=..., mask=...)``
            which is called with `source` of shape (N, T) and returns two
            values: the features, whose last dimension is `encoder_dim`,
            and the padding mask that matches those features.
          encoder_dim:
            Dimension of the features produced by the encoder; it is the
            input dimension of the CTC output layer.
          vocab_size:
            Number of output classes of the CTC output layer.
        """
        super().__init__()

        self.encoder = encoder

        # Modules for CTC head: dropout -> linear projection -> log-softmax,
        # so the head directly yields the log-probs expected by the CTC loss.
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )

    def forward_encoder(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          padding_mask:
            Optional bool tensor passed to the encoder together with `x`.
            If None, an all-False mask with the same shape as `x` is used,
            i.e. nothing is treated as padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,). It is the number of
            False entries per row of the padding mask returned by the
            encoder.
        """
        if padding_mask is None:
            padding_mask = torch.zeros_like(x, dtype=torch.bool)

        # Masking inside the encoder is enabled only while the encoder is in
        # training mode.
        encoder_out, padding_mask = self.encoder.extract_features(
            source=x,
            padding_mask=padding_mask,
            mask=self.encoder.training,
        )
        # The returned mask is True at padded positions, so counting the
        # False entries gives the number of valid output frames.
        encoder_out_lens = torch.sum(~padding_mask, dim=1)
        assert torch.all(encoder_out_lens > 0), encoder_out_lens

        return encoder_out, encoder_out_lens

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Number of targets of each utterance, of shape (N,).

        Returns:
          The CTC loss summed over all utterances in the batch
          (``reduction="sum"``).
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )
        return ctc_loss

    def forward(
        self,
        x: torch.Tensor,
        y: "k2.RaggedTensor",
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and compute the CTC loss for a batch.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          padding_mask:
            Optional padding mask forwarded to :func:`forward_encoder`.

        Returns:
          Return a tuple of two tensors:

            - ctc_loss: the CTC loss summed over the batch.
            - encoder_out_lens: encoder output lengths, of shape (N,).
        """
        assert x.ndim == 2, x.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == y.dim0, (x.shape, y.dim0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)

        # Per-utterance label counts are the differences between consecutive
        # row splits of the ragged tensor.
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        # Compute CTC loss; `y.values` is the flat concatenation of all labels.
        targets = y.values
        ctc_loss = self.forward_ctc(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            targets=targets,
            target_lengths=y_lens,
        )

        return ctc_loss, encoder_out_lens
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/optim.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/optim.py
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/scaling.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/scaling.py
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/utils.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/utils.py
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/wav2vec2_module.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/wav2vec2_module.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/wav2vec2_module.py
|
||||||
1
egs/librispeech/SSL/zipformer_ctc/zipformer.py
Symbolic link
1
egs/librispeech/SSL/zipformer_ctc/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../zipformer/zipformer.py
|
||||||
@ -70,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs.
|
Return a list-of-list of token IDs.
|
||||||
"""
|
"""
|
||||||
assert sep in ("", "/"), sep
|
assert sep in ("", "/", "|"), sep
|
||||||
ids: List[List[int]] = []
|
ids: List[List[int]] = []
|
||||||
whitespace = re.compile(r"([ \t])")
|
whitespace = re.compile(r"([ \t])")
|
||||||
for text in texts:
|
for text in texts:
|
||||||
if sep == "":
|
if sep == "":
|
||||||
text = re.sub(whitespace, "", text)
|
text = re.sub(whitespace, "", text)
|
||||||
|
elif sep == "|":
|
||||||
|
text = re.sub(r"\s+", " ", text)
|
||||||
|
text = re.sub(" ", "|", text)
|
||||||
else:
|
else:
|
||||||
text = text.split(sep)
|
text = text.split(sep)
|
||||||
sub_ids = [
|
sub_ids = [
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user