This commit is contained in:
yfyeung 2024-04-01 11:09:25 +08:00
parent dfbacbe4dc
commit 3ecf81e793
22 changed files with 8557 additions and 0 deletions

View File

@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_gigaspeech():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
batch_duration = 1000
subsets = ("L", "M", "S", "XS", "DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
overwrite=True,
)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_gigaspeech()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from datetime import datetime
from pathlib import Path
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-workers",
type=int,
default=20,
help="Number of dataloading workers used for reading the audio.",
)
parser.add_argument(
"--batch-duration",
type=float,
default=600.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--num-splits",
type=int,
required=True,
help="The number of splits of the XL subset",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
return parser
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
output_dir = f"data/fbank/XL_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
num_digits = 8 # num_digits is fixed by lhotse split-lazy
start = args.start
stop = args.stop
if stop < start:
stop = num_splits
stop = min(stop, num_splits)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
for i in range(start, stop):
idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
log_filename = "log-compute_fbank_gigaspeech_splits"
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
log_filename = f"{log_filename}-{date_time}"
logging.basicConfig(
filename=log_filename,
format=formatter,
level=logging.INFO,
filemode="w",
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank_gigaspeech_splits(args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
def normalize_text(
utt: str,
) -> str:
whitespace_pattern = (re.compile(r"\s\s+"),)
return whitespace_pattern.sub("", utt)
def preprocess_gigaspeech2(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
dataset_parts = ("test",)
logging.info("Loading manifest (may take 4 minutes)")
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix="gigaspeech2",
suffix="jsonl.gz",
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"gigaspeech2_cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
for sup in m["supervisions"]:
sup.text = normalize_text(sup.text)
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_gigaspeech2()
if __name__ == "__main__":
main()

54
egs/gigaspeech2/SSL/prepare.sh Executable file
View File

@ -0,0 +1,54 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=16
# run step 1 to step 5 by default
stage=1
stop_stage=5
# We assume dl_dir (download dir) contains the following directories and files.
#
# - $dl_dir/GigaSpeech2
dl_dir=$PWD/download
lang=Thai
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "Running prepare.sh"
log "dl_dir: $dl_dir"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare GigaSpeech2 manifest, language: $lang"
# We assume that you have downloaded the GigaSpeech2 corpus
# to $dl_dir/GigaSpeech2
mkdir -p data/manifests
if [ ! -e data/manifests/.gigaspeech2.done ]; then
lhotse prepare gigaspeech2 --lang $lang -j $nj $dl_dir/GigaSpeech2 data/manifests
touch data/manifests/.gigaspeech2.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute fbank for gigaspeech2"
mkdir -p data/fbank
if [ ! -e data/fbank/.gigaspeech2.done ]; then
./local/compute_fbank_gigaspeech2.py
touch data/fbank/.gigaspeech2.done
fi
fi

1
egs/gigaspeech2/SSL/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -0,0 +1,287 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from dataset import HubertAsrDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
"""
DataModule for ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/wav"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=float,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--do-normalize",
type=str2bool,
default=True,
help="whether to normalize the data",
)
def train_dataloaders(
self,
cuts_train: CutSet,
do_normalize: bool,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = HubertAsrDataset(do_normalize=do_normalize)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertAsrDataset(do_normalize=do_normalize)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet, do_normalize: bool) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertAsrDataset(do_normalize=do_normalize)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
return CutSet.mux(
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
weights=[
28539, # len(train_clean_100_cuts)
104014, # len(train_clean_360_cuts)
148688, # len(train_other_500_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/beam_search.py

View File

@ -0,0 +1,218 @@
# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import collate_features
from lhotse.workarounds import Hdf5MemoryIssueFix
from torch.utils.data.dataloader import default_collate
class HubertDataset(torch.utils.data.Dataset):
"""
In this implementation, there will always be a single channel.
Returns:
.. code-block::
{
'features': (B, T, F) float tensor
}
Dimension symbols legend:
* ``B`` - batch size (number of Cuts)
* ``T`` - number of frames of the longest Cut
* ``F`` - number of features
"""
def __init__(
self,
max_sample_size: Optional[int] = None,
sample_rate: float = 100,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.label_rate = label_rate
self.random_crop = random_crop
self.pad_feature = pad_audio
self.num_classes = num_classes
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
# This attribute is a workaround to constantly growing HDF5 memory
# throughout the epoch. It regularly closes open file handles to
# reset the internal HDF5 caches.
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
self._validate(cuts)
self.hdf5_fix.update()
# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)
features = [torch.from_numpy(cut.load_features()) for cut in cuts]
feature_lens = [cut.num_frames for cut in cuts]
if self.pad_feature:
feature_size = min(max(feature_lens), self.max_sample_size)
else:
feature_size = min(min(feature_lens), self.max_sample_size)
features, padding_mask, feature_starts = self.collater_feature(
features, feature_lens, feature_size
)
kmeans = [cut.custom["kmeans"] for cut in cuts]
kmeans = [
torch.tensor([int(item) for item in label.split()], dtype=torch.int64)
for label in kmeans
]
kmeans, kmeans_lens = self.collater_frm_label(kmeans, feature_size, feature_starts)
return {
"cuts": cuts,
"features": features,
"padding_mask": padding_mask,
"kmeans": kmeans,
}
def _validate(self, cuts: CutSet) -> None:
validate(cuts)
assert all(cut.has_recording for cut in cuts)
def crop_to_max_size(self, feature, target_size):
size = len(feature)
diff = size - target_size
if diff <= 0:
return feature, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return feature[start:end, :], start
def collater_feature(self, features, feature_lens, feature_size):
feature_dim = features[0].shape[-1]
collated_features = features[0].new_zeros(len(features), feature_size, feature_dim)
padding_mask = (
torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
# if self.pad_feature else None
)
feature_starts = [0 for _ in features]
for i, (feature, feature_len) in enumerate(zip(features, feature_lens)):
diff = feature_len - feature_size
if diff == 0:
collated_features[i] = feature
elif diff < 0:
assert self.pad_feature
collated_features[i] = torch.cat([feature, feature.new_full((-diff, feature_dim), 0.0)])
padding_mask[i, diff:] = True
else:
collated_features[i], feature_starts[i] = self.crop_to_max_size(
feature, feature_size
)
return collated_features, padding_mask, feature_starts
def collate_tokens(
self,
values,
pad_idx,
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
pad_to_length=None,
pad_to_multiple=1,
pad_to_bsz=None,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
res = values[0].new(batch_size, size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
if eos_idx is None:
# if no eos_idx is specified, then use the last token in src
dst[0] = src[-1]
else:
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
def collater_frm_label(self, targets, feature_size, feature_starts):
label_rate = self.label_rate
pad = self.num_classes[0] - 1
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in feature_starts]
frm_size = int(round(feature_size * s2f))
if not self.pad_feature:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
lengths = torch.LongTensor([len(t) for t in targets])
targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths
if __name__ == "__main__":
from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler
from torch.utils.data import DataLoader
dataset = HubertDataset(max_sample_size=1562)
cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
sampler = DynamicBucketingSampler(
cuts,
max_duration=300,
shuffle=False,
)
dl = DataLoader(
dataset,
batch_size=None,
sampler=sampler,
num_workers=0,
)
for batch_idx, batch in enumerate(dl):
print(batch["features"].shape)
print(batch["padding_mask"].shape)
print(batch["kmeans"].shape)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/encoder_interface.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,585 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from utils import LayerNorm
from zipformer import Zipformer2
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
add_masks: bool = False,
seed: Optional[int] = None,
epoch: Optional[int] = None,
indices: Optional[torch.Tensor] = None,
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
mask_dropout: randomly dropout this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
if num_mask_ver == 1:
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if seed is not None and epoch is not None and indices is not None:
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
else:
seed_i = None
rng = np.random.default_rng(seed_i)
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
assert sz >= 0, sz
else:
sz = all_sz
if num_mask_ver == 1:
if padding_mask is not None:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
num_mask = all_num_mask
elif num_mask_ver == 2:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ rng.random()
)
num_mask = max(min_masks, num_mask)
else:
raise ValueError()
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = rng.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = rng.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
if mask_type == "static":
raise ValueError(f"this should never happens")
else:
lengths = [min(mask_length, sz - 1)]
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = rng.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = rng.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
if idc_select_ver == 1:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
elif idc_select_ver == 2:
mask_idc = rng.choice(sz, num_mask, replace=False)
else:
raise ValueError()
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idc = np.unique(mask_idc[mask_idc < sz])
if len(mask_idc) >= sz:
raise ValueError(
(
f"the entire sequence is masked. "
f"sz={sz}; mask_idc[mask_idc]; "
f"index={indices[i] if indices is not None else None}"
)
)
mask_idcs.append(mask_idc)
target_len = None
if require_same_masks:
if add_masks:
target_len = max([len(m) for m in mask_idcs])
else:
target_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if target_len is not None and len(mask_idc) > target_len:
mask_idc = rng.choice(mask_idc, target_len, replace=False)
mask[i, mask_idc] = True
if target_len is not None and len(mask_idc) < target_len:
unmasked = np.flatnonzero(~mask[i])
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
mask[i, to_mask] = True
if mask_dropout > 0:
masked = np.flatnonzero(mask[i])
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
to_drop = rng.choice(masked, num_holes, replace=False)
mask[i, to_drop] = False
return mask
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
class HubertModel(nn.Module):
def __init__(
self,
cfg,
) -> None:
super().__init__()
self.embed = feature_enc_layers[-1][0]
self.encoder_embed = Conv2dSubsampling(
in_channels=cfg.feature_dim,
out_channels=_to_int_tuple(cfg.encoder_dim)[0],
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
)
self.feat2tar_ratio = (
cfg.label_rate * feature_ds_rate / cfg.sample_rate
) # TODO feature_ds_rate 320
encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())
self.encoder = Zipformer2(
output_downsampling_factor=1,
downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
encoder_dim=_to_int_tuple(cfg.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(cfg.query_head_dim),
pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
value_head_dim=_to_int_tuple(cfg.value_head_dim),
pos_dim=cfg.pos_dim,
num_heads=_to_int_tuple(cfg.num_heads),
feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
)
self.layer_norm = LayerNorm(self.embed)
self.untie_final_proj = cfg.untie_final_proj
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
# modules below are not needed during fine-tuning
self.num_classes = cfg.num_classes
self.pred_masked_weight = cfg.pred_masked_weight
self.pred_nomask_weight = cfg.pred_nomask_weight
self.loss_weights = cfg.loss_weights
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
super().upgrade_state_dict_named(state_dict, name)
return state_dict
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb.to(x.dtype)
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
features = self.encoder_embed(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
):
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float -> (T, B, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x = x.transpose(0, 1)
x, x_lens = self.encoder(x, ~padding_mask.sum(dim=-1))
x = x.transpose(0, 1)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
proj_x_m /= self.logit_temp
logit_m_list = [proj_x_m for _ in range(len(target_list))]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
proj_x_u /= self.logit_temp
logit_u_list = [proj_x_u for _ in range(len(target_list))]
else:
logit_u_list = [None for _ in target_list]
# result = {
# "logit_m_list": logit_m_list,
# "logit_u_list": logit_u_list,
# "padding_mask": padding_mask,
# "features_pen": features_pen,
# }
targ_m_list = target_list[0][masked_indices]
targ_m_list = targ_m_list.long()
targ_m_list = [targ_m_list for _ in range(len(target_list))]
targ_u_list = target_list[0][nomask_indices]
targ_u_list = targ_u_list.long()
targ_u_list = [targ_u_list for _ in range(len(target_list))]
return self.compute_loss(
logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
)
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x.float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self):
self.final_proj = None
def compute_loss(
self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
):
loss = 0.0
sample_size = 0
logging_output = {}
reduce = True
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = [x.float() for x in logit_m_list if x is not None]
logp_m_list = torch.cat(logp_m_list)
targ_m_list = torch.cat(targ_m_list)
loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_0"] = loss_m.detach().item()
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += len(targ_m_list)
loss_u_list = []
logp_u_list = [x.float() for x in logit_u_list if x is not None]
logp_u_list = torch.cat(logp_u_list)
targ_u_list = torch.cat(targ_u_list)
loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_0"] = loss_u.detach().item()
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += len(targ_u_list)
if self.loss_weights is not None:
extra_losses = []
names = []
extra_losses.append(features_pen)
names.append("features_pen")
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
**logging_output,
}
# for lk in self.log_keys:
# if lk in net_output:
# logging_output[lk] = float((net_output[lk]))
def compute_correct(logits, target):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == target
min = logits.argmin(-1) == target
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
logging_output[f"correct_m_0"] = corr_m
logging_output[f"count_m_0"] = count_m
corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
logging_output[f"correct_u_0"] = corr_u
logging_output[f"count_u_0"] = count_u
return loss, sample_size, logging_output

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/joiner.py

View File

@ -0,0 +1,344 @@
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
class AsrModel(nn.Module):
def __init__(
self,
encoder,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
encoder_dim: int = 768,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder:
It is the transcription network in the paper. Its accepts
inputs: `x` of (N, T, encoder_dim).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 2-D tensor of shape (N, T).
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
if padding_mask is None:
padding_mask = torch.zeros_like(x, dtype=torch.bool)
encoder_out, padding_mask = self.encoder.extract_features(
source=x,
padding_mask=padding_mask,
mask=self.encoder.training,
)
encoder_out_lens = torch.sum(~padding_mask, dim=1)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
y: k2.RaggedTensor,
padding_mask: Optional[torch.Tensor] = None,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 2-D tensor of shape (N, T).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 2, x.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == y.dim0, (x.shape, y.dim0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, encoder_out_lens

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/optim.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/scaling.py

View File

@ -0,0 +1,341 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from dataset import HubertDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechDataModule:
"""
DataModule for SSL experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in SSL
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
This class should be derived for specific corpora used in SSL tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="SSL data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/kmeans"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=float,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--do-normalize",
type=str2bool,
default=True,
help="whether to normalize the data",
)
group.add_argument(
"--random-crop",
type=str2bool,
default=True,
help="audio sample rate",
)
def train_dataloaders(
self,
cuts_train: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(
self,
cuts_valid: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(
self,
cuts: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
return CutSet.mux(
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
weights=[
28539, # len(train_clean_100_cuts)
104014, # len(train_clean_360_cuts)
148688, # len(train_other_500_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

View File

@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/utils.py

File diff suppressed because it is too large Load Diff