Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 3ecf81e793 ("init"), parent dfbacbe4dc.
egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech.py (new executable file)
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_gigaspeech():
    in_out_dir = Path("data/fbank")
    # number of workers in dataloader
    num_workers = 20

    # number of seconds in a batch
    batch_duration = 1000

    subsets = ("L", "M", "S", "XS", "DEV", "TEST")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

    logging.info(f"device: {device}")

    for partition in subsets:
        cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            continue

        raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"

        logging.info(f"Loading {raw_cuts_path}")
        cut_set = CutSet.from_file(raw_cuts_path)

        logging.info("Computing features")

        cut_set = cut_set.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
            num_workers=num_workers,
            batch_duration=batch_duration,
            overwrite=True,
        )
        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False, min_duration=None
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
        logging.info(f"Saved to {cuts_path}")


def main():
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_gigaspeech()


if __name__ == "__main__":
    main()
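For context, the whole script boils down to one lhotse call per subset. Below is a minimal sketch of the same flow for a single raw manifest, using the DEV paths from the script's own conventions (a sketch, not part of the commit):

```python
# Minimal sketch of the per-partition flow above, for one raw cut manifest.
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig

torch.set_num_threads(1)  # same caveat as the script: avoid CPU oversubscription

extractor = KaldifeatFbank(KaldifeatFbankConfig(device=torch.device("cpu")))
cut_set = CutSet.from_file("data/fbank/gigaspeech_cuts_DEV_raw.jsonl.gz")
cut_set = cut_set.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data/fbank/gigaspeech_feats_DEV",
    num_workers=4,
    batch_duration=1000,
    overwrite=True,
)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_set.to_file("data/fbank/gigaspeech_cuts_DEV.jsonl.gz")
```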
egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech_splits.py (new executable file)
@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from datetime import datetime
from pathlib import Path

import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--num-workers",
        type=int,
        default=20,
        help="Number of dataloading workers used for reading the audio.",
    )
    parser.add_argument(
        "--batch-duration",
        type=float,
        default=600.0,
        help="The maximum number of audio seconds in a batch. "
        "Determines batch size dynamically.",
    )

    parser.add_argument(
        "--num-splits",
        type=int,
        required=True,
        help="The number of splits of the XL subset",
    )

    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Process pieces starting from this number (inclusive).",
    )

    parser.add_argument(
        "--stop",
        type=int,
        default=-1,
        help="Stop processing pieces at this number (exclusive).",
    )
    return parser


def compute_fbank_gigaspeech_splits(args):
    num_splits = args.num_splits
    output_dir = "data/fbank/XL_split"
    output_dir = Path(output_dir)
    assert output_dir.exists(), f"{output_dir} does not exist!"

    num_digits = 8  # num_digits is fixed by lhotse split-lazy

    start = args.start
    stop = args.stop
    if stop < start:
        stop = num_splits

    stop = min(stop, num_splits)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
    logging.info(f"device: {device}")

    for i in range(start, stop):
        idx = f"{i}".zfill(num_digits)
        logging.info(f"Processing {idx}/{num_splits}")

        cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            continue

        raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"

        logging.info(f"Loading {raw_cuts_path}")
        cut_set = CutSet.from_file(raw_cuts_path)

        logging.info("Computing features")

        cut_set = cut_set.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
            num_workers=args.num_workers,
            batch_duration=args.batch_duration,
            overwrite=True,
        )

        logging.info("About to trim cuts to supervisions.")
        cut_set = cut_set.trim_to_supervisions(
            keep_overlapping=False, min_duration=None
        )

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
        logging.info(f"Saved to {cuts_path}")


def main():
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")

    log_filename = "log-compute_fbank_gigaspeech_splits"
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    log_filename = f"{log_filename}-{date_time}"

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=logging.INFO,
        filemode="w",
    )

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(formatter))
    logging.getLogger("").addHandler(console)

    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    compute_fbank_gigaspeech_splits(args)


if __name__ == "__main__":
    main()
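This script processes pre-split pieces of the XL subset; the self-contained sketch below (not part of the commit) illustrates the index convention it relies on: `lhotse split-lazy` writes 8-digit, zero-padded suffixes, and `--start`/`--stop` select a half-open range of those pieces.

```python
# The piece-naming convention compute_fbank_gigaspeech_splits.py assumes.
num_digits = 8  # fixed by lhotse split-lazy, as the script notes
start, stop = 0, 3  # e.g. --start 0 --stop 3

for i in range(start, stop):
    idx = f"{i}".zfill(num_digits)
    print(f"data/fbank/XL_split/gigaspeech_cuts_XL_raw.{idx}.jsonl.gz")
# data/fbank/XL_split/gigaspeech_cuts_XL_raw.00000000.jsonl.gz
# data/fbank/XL_split/gigaspeech_cuts_XL_raw.00000001.jsonl.gz
# data/fbank/XL_split/gigaspeech_cuts_XL_raw.00000002.jsonl.gz
```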
egs/gigaspeech2/SSL/local/preprocess_gigaspeech2.py (new executable file)
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from pathlib import Path

from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import str2bool


def normalize_text(
    utt: str,
) -> str:
    # Note: the committed version assigned a one-element tuple here
    # (a trailing comma after re.compile), which would crash on .sub().
    whitespace_pattern = re.compile(r"\s\s+")
    return whitespace_pattern.sub(" ", utt)


def preprocess_gigaspeech2():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    output_dir.mkdir(exist_ok=True)

    dataset_parts = ("test",)

    logging.info("Loading manifest (may take 4 minutes)")
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix="gigaspeech2",
        suffix="jsonl.gz",
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    for partition, m in manifests.items():
        logging.info(f"Processing {partition}")
        raw_cuts_path = output_dir / f"gigaspeech2_cuts_{partition}_raw.jsonl.gz"
        if raw_cuts_path.is_file():
            logging.info(f"{partition} already exists - skipping")
            continue

        for sup in m["supervisions"]:
            sup.text = normalize_text(sup.text)

        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )

        logging.info(f"Saving to {raw_cuts_path}")
        cut_set.to_file(raw_cuts_path)


def main():
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    preprocess_gigaspeech2()


if __name__ == "__main__":
    main()
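A quick check of the whitespace normalization above. This assumes the intended behavior is to collapse runs of whitespace to a single space (the committed code substituted an empty string, which would glue words together):

```python
import re

whitespace_pattern = re.compile(r"\s\s+")
print(whitespace_pattern.sub(" ", "hello   world\t\tfoo  bar"))
# -> "hello world foo bar"
```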
egs/gigaspeech2/SSL/prepare.sh (new executable file)
@@ -0,0 +1,54 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=16
# run stage 1 to stage 5 by default
stage=1
stop_stage=5

# We assume dl_dir (download dir) contains the following directories and files.
#
#  - $dl_dir/GigaSpeech2

dl_dir=$PWD/download
lang=Thai

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "Running prepare.sh"

log "dl_dir: $dl_dir"

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare GigaSpeech2 manifest, language: $lang"
  # We assume that you have downloaded the GigaSpeech2 corpus
  # to $dl_dir/GigaSpeech2
  mkdir -p data/manifests
  if [ ! -e data/manifests/.gigaspeech2.done ]; then
    lhotse prepare gigaspeech2 --lang $lang -j $nj $dl_dir/GigaSpeech2 data/manifests
    touch data/manifests/.gigaspeech2.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Compute fbank for gigaspeech2"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.gigaspeech2.done ]; then
    ./local/compute_fbank_gigaspeech2.py
    touch data/fbank/.gigaspeech2.done
  fi
fi
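Usage note: assuming the standard Kaldi-style behavior of `shared/parse_options.sh` (symlinked below), `--stage` and `--stop-stage` flags map onto the `stage`/`stop_stage` variables, so individual stages can be rerun in isolation, e.g. `./prepare.sh --stage 2 --stop-stage 2`. The `.done` marker files additionally make each stage a no-op once it has completed.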
egs/gigaspeech2/SSL/shared (new symbolic link)
@@ -0,0 +1 @@
../../../icefall/shared/
egs/gigaspeech2/SSL/zipformer/asr_datamodule.py (new file)
@@ -0,0 +1,287 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from dataset import HubertAsrDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LibriSpeechAsrDataModule:
    """
    DataModule for ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="When enabled use 960h LibriSpeech. Otherwise, use 100h subset.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/wav"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--do-normalize",
            type=str2bool,
            default=True,
            help="Whether to normalize the data.",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        do_normalize: bool,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = HubertAsrDataset(do_normalize=do_normalize)

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
        logging.info("About to create dev dataset")
        validate = HubertAsrDataset(do_normalize=do_normalize)
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet, do_normalize: bool) -> DataLoader:
        logging.debug("About to create test dataset")
        test = HubertAsrDataset(do_normalize=do_normalize)
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_clean_100_cuts(self) -> CutSet:
        logging.info("About to get train-clean-100 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
        )

    @lru_cache()
    def train_clean_360_cuts(self) -> CutSet:
        logging.info("About to get train-clean-360 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
        )

    @lru_cache()
    def train_other_500_cuts(self) -> CutSet:
        logging.info("About to get train-other-500 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        train_clean_100_cuts = self.train_clean_100_cuts()
        train_clean_360_cuts = self.train_clean_360_cuts()
        train_other_500_cuts = self.train_other_500_cuts()
        return CutSet.mux(
            train_clean_100_cuts,
            train_clean_360_cuts,
            train_other_500_cuts,
            weights=[
                28539,  # len(train_clean_100_cuts)
                104014,  # len(train_clean_360_cuts)
                148688,  # len(train_other_500_cuts)
            ],
        )

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
        )
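A hedged sketch of how this datamodule is typically wired up from a training script. It assumes being run from the recipe directory, so that `asr_datamodule` and the `dataset` module it imports both resolve:

```python
import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "100"])

datamodule = LibriSpeechAsrDataModule(args)
train_cuts = datamodule.train_clean_100_cuts()  # lazily loads the manifest
train_dl = datamodule.train_dataloaders(train_cuts, do_normalize=args.do_normalize)
```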
egs/gigaspeech2/SSL/zipformer/beam_search.py (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/beam_search.py
egs/gigaspeech2/SSL/zipformer/dataset.py (new file)
@@ -0,0 +1,218 @@
# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import collate_features
from lhotse.workarounds import Hdf5MemoryIssueFix
from torch.utils.data.dataloader import default_collate


class HubertDataset(torch.utils.data.Dataset):
    """
    In this implementation, there will always be a single channel.

    Returns:

    .. code-block::

        {
            'features': (B, T, F) float tensor
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features
    """

    def __init__(
        self,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 100,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.label_rate = label_rate
        self.random_crop = random_crop
        self.pad_feature = pad_audio
        self.num_classes = num_classes
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        self._validate(cuts)
        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        features = [torch.from_numpy(cut.load_features()) for cut in cuts]
        feature_lens = [cut.num_frames for cut in cuts]

        if self.pad_feature:
            feature_size = min(max(feature_lens), self.max_sample_size)
        else:
            feature_size = min(min(feature_lens), self.max_sample_size)

        features, padding_mask, feature_starts = self.collater_feature(
            features, feature_lens, feature_size
        )

        kmeans = [cut.custom["kmeans"] for cut in cuts]
        kmeans = [
            torch.tensor([int(item) for item in label.split()], dtype=torch.int64)
            for label in kmeans
        ]
        kmeans, kmeans_lens = self.collater_frm_label(
            kmeans, feature_size, feature_starts
        )

        return {
            "cuts": cuts,
            "features": features,
            "padding_mask": padding_mask,
            "kmeans": kmeans,
        }

    def _validate(self, cuts: CutSet) -> None:
        validate(cuts)
        assert all(cut.has_recording for cut in cuts)

    def crop_to_max_size(self, feature, target_size):
        size = len(feature)
        diff = size - target_size
        if diff <= 0:
            return feature, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return feature[start:end, :], start

    def collater_feature(self, features, feature_lens, feature_size):
        feature_dim = features[0].shape[-1]
        collated_features = features[0].new_zeros(
            len(features), feature_size, feature_dim
        )
        padding_mask = (
            torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
            # if self.pad_feature else None
        )
        feature_starts = [0 for _ in features]
        for i, (feature, feature_len) in enumerate(zip(features, feature_lens)):
            diff = feature_len - feature_size
            if diff == 0:
                collated_features[i] = feature
            elif diff < 0:
                assert self.pad_feature
                collated_features[i] = torch.cat(
                    [feature, feature.new_full((-diff, feature_dim), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_features[i], feature_starts[i] = self.crop_to_max_size(
                    feature, feature_size
                )
        return collated_features, padding_mask, feature_starts

    def collate_tokens(
        self,
        values,
        pad_idx,
        eos_idx=None,
        left_pad=False,
        move_eos_to_beginning=False,
        pad_to_length=None,
        pad_to_multiple=1,
        pad_to_bsz=None,
    ):
        """Convert a list of 1d tensors into a padded 2d tensor."""
        size = max(v.size(0) for v in values)
        size = size if pad_to_length is None else max(size, pad_to_length)
        if pad_to_multiple != 1 and size % pad_to_multiple != 0:
            size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

        batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
        res = values[0].new(batch_size, size).fill_(pad_idx)

        def copy_tensor(src, dst):
            assert dst.numel() == src.numel()
            if move_eos_to_beginning:
                if eos_idx is None:
                    # if no eos_idx is specified, then use the last token in src
                    dst[0] = src[-1]
                else:
                    dst[0] = eos_idx
                dst[1:] = src[:-1]
            else:
                dst.copy_(src)

        for i, v in enumerate(values):
            copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
        return res

    def collater_frm_label(self, targets, feature_size, feature_starts):
        label_rate = self.label_rate
        pad = self.num_classes[0] - 1
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in feature_starts]
        frm_size = int(round(feature_size * s2f))
        if not self.pad_feature:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]

        lengths = torch.LongTensor([len(t) for t in targets])
        targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths


if __name__ == "__main__":
    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler
    from torch.utils.data import DataLoader

    dataset = HubertDataset(max_sample_size=1562)
    cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=300,
        shuffle=False,
    )
    dl = DataLoader(
        dataset,
        batch_size=None,
        sampler=sampler,
        num_workers=0,
    )

    for batch_idx, batch in enumerate(dl):
        print(batch["features"].shape)
        print(batch["padding_mask"].shape)
        print(batch["kmeans"].shape)
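A standalone illustration (not part of the commit) of the right-padding that `collate_tokens` performs on the k-means label sequences; `pad_idx` is `num_classes[0] - 1`, as in `collater_frm_label`:

```python
import torch

values = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
pad_idx = 503  # num_classes[0] - 1 for the default num_classes=[504]

size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)  # (batch, longest), pad-filled
for i, v in enumerate(values):
    res[i][: len(v)].copy_(v)  # left-aligned copy, i.e. right padding
print(res)
# tensor([[  1,   2,   3],
#         [  4,   5, 503]])
```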
egs/gigaspeech2/SSL/zipformer/decode.py (new file, 1045 lines)
File diff suppressed because it is too large.
egs/gigaspeech2/SSL/zipformer/decoder.py (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/decoder.py
egs/gigaspeech2/SSL/zipformer/encoder_interface.py (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/encoder_interface.py
egs/gigaspeech2/SSL/zipformer/finetune.py (new file, 1552 lines)
File diff suppressed because it is too large.
egs/gigaspeech2/SSL/zipformer/hubert_ce.py (new file)
@@ -0,0 +1,585 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import argparse
import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from utils import LayerNorm
from zipformer import Zipformer2


def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks. Should be of size 2,
            where the first element is the batch size and the second is
            the number of timesteps.
        padding_mask: optional padding mask of the same size as shape, which
            will prevent masking padded elements.
        mask_prob: probability for each token to be chosen as the start of a
            span to be masked. This will be multiplied by the number of
            timesteps divided by the mask span length, to mask approximately
            this percentage of all elements. However, due to overlaps the
            actual number will be smaller (unless no_overlap is True).
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and
                stdev mask_other; mask is min 1 element
            poisson = sample from Poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if true, uses a recursive algorithm that prevents spans
            from overlapping
        min_space: only used if no_overlap is True; how many elements to keep
            unmasked between spans
        require_same_masks: if true, randomly drops out masks until the same
            number of masks remains in each sample
        mask_dropout: randomly drop out this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # np.random.Generator has no randint(); integers() is the replacement
            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happen")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # np.int was removed in NumPy >= 1.24; plain int is equivalent
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc[mask_idc]; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask


def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))


class HubertModel(nn.Module):
    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        # carried over from fairseq's HubertModel; not defined in this file
        self.embed = feature_enc_layers[-1][0]

        self.encoder_embed = Conv2dSubsampling(
            in_channels=cfg.feature_dim,
            out_channels=_to_int_tuple(cfg.encoder_dim)[0],
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        )
        self.feat2tar_ratio = (
            cfg.label_rate * feature_ds_rate / cfg.sample_rate
        )  # TODO feature_ds_rate 320
        encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
        encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())

        self.encoder = Zipformer2(
            output_downsampling_factor=1,
            downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
            num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
            encoder_dim=_to_int_tuple(cfg.encoder_dim),
            encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
            query_head_dim=_to_int_tuple(cfg.query_head_dim),
            pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
            value_head_dim=_to_int_tuple(cfg.value_head_dim),
            pos_dim=cfg.pos_dim,
            num_heads=_to_int_tuple(cfg.num_heads),
            feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
            cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
            warmup_batches=4000.0,
        )

        self.layer_norm = LayerNorm(self.embed)

        self.untie_final_proj = cfg.untie_final_proj
        self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))

        # modules below are not needed during fine-tuning
        self.num_classes = cfg.num_classes
        self.pred_masked_weight = cfg.pred_masked_weight
        self.pred_nomask_weight = cfg.pred_nomask_weight
        self.loss_weights = cfg.loss_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb.to(x.dtype)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        features = self.encoder_embed(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ):
        """output layer is 1-based"""
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float -> (T, B, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x = x.transpose(0, 1)
        # number of unmasked frames per utterance
        x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
        x = x.transpose(0, 1)

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            proj_x_m /= self.logit_temp
            logit_m_list = [proj_x_m for _ in range(len(target_list))]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            proj_x_u /= self.logit_temp
            logit_u_list = [proj_x_u for _ in range(len(target_list))]
        else:
            logit_u_list = [None for _ in target_list]

        # result = {
        #     "logit_m_list": logit_m_list,
        #     "logit_u_list": logit_u_list,
        #     "padding_mask": padding_mask,
        #     "features_pen": features_pen,
        # }
        targ_m_list = target_list[0][masked_indices]
        targ_m_list = targ_m_list.long()
        targ_m_list = [targ_m_list for _ in range(len(target_list))]

        targ_u_list = target_list[0][nomask_indices]
        targ_u_list = targ_u_list.long()
        targ_u_list = [targ_u_list for _ in range(len(target_list))]
        return self.compute_loss(
            logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
        )

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        self.final_proj = None

    def compute_loss(
        self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
    ):
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduce = True
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = [x.float() for x in logit_m_list if x is not None]
        logp_m_list = torch.cat(logp_m_list)
        targ_m_list = torch.cat(targ_m_list)

        loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
        loss_m_list.append(loss_m)
        logging_output["loss_m_0"] = loss_m.detach().item()

        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += len(targ_m_list)

        loss_u_list = []
        logp_u_list = [x.float() for x in logit_u_list if x is not None]
        logp_u_list = torch.cat(logp_u_list)
        targ_u_list = torch.cat(targ_u_list)

        loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
        loss_u_list.append(loss_u)
        logging_output["loss_u_0"] = loss_u.detach().item()

        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += len(targ_u_list)

        if self.loss_weights is not None:
            extra_losses = []
            names = []
            extra_losses.append(features_pen)
            names.append("features_pen")
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            **logging_output,
        }

        # for lk in self.log_keys:
        #     if lk in net_output:
        #         logging_output[lk] = float((net_output[lk]))

        def compute_correct(logits, target):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == target
                min = logits.argmin(-1) == target
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
            logging_output["correct_m_0"] = corr_m
            logging_output["count_m_0"] = count_m

            corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
            logging_output["correct_u_0"] = corr_u
            logging_output["count_u_0"] = count_u

        return loss, sample_size, logging_output
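A usage sketch for `compute_mask_indices` (assuming `hubert_ce.py` is importable from the recipe directory); this mirrors how `HubertModel.apply_mask` calls it for a `(B, T)` batch:

```python
import torch

from hubert_ce import compute_mask_indices

mask = compute_mask_indices(
    shape=(2, 100),  # (batch, timesteps)
    padding_mask=None,
    mask_prob=0.65,
    mask_length=10,
    min_masks=2,
)
mask = torch.from_numpy(mask)  # boolean (2, 100) span mask
# With require_same_masks=True (the default), both rows mask
# the same number of timesteps (roughly 65% each, before overlaps).
print(mask.shape, mask.sum(dim=1))
```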
egs/gigaspeech2/SSL/zipformer/joiner.py (new symbolic link)
@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/joiner.py
egs/gigaspeech2/SSL/zipformer/model.py (new file)
@@ -0,0 +1,344 @@
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Zengwei Yao,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
    def __init__(
        self,
        encoder,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        encoder_dim: int = 768,
        decoder_dim: int = 512,
        vocab_size: int = 500,
        use_transducer: bool = True,
        use_ctc: bool = False,
    ):
        """A joint CTC & Transducer ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
        - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
        - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

        Args:
          encoder:
            It is the transcription network in the paper. It accepts
            inputs: `x` of shape (N, T, encoder_dim).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
            It is used only when use_transducer is True.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
            It is used only when use_transducer is True.
          use_transducer:
            Whether to use the transducer head. Default: True.
          use_ctc:
            Whether to use the CTC head. Default: False.
        """
        super().__init__()

        assert (
            use_transducer or use_ctc
        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

        self.encoder = encoder

        self.use_transducer = use_transducer
        if use_transducer:
            # Modules for Transducer head
            assert decoder is not None
            assert hasattr(decoder, "blank_id")
            assert joiner is not None

            self.decoder = decoder
            self.joiner = joiner

            self.simple_am_proj = ScaledLinear(
                encoder_dim, vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                decoder_dim, vocab_size, initial_scale=0.25
            )
        else:
            assert decoder is None
            assert joiner is None

        self.use_ctc = use_ctc
        if use_ctc:
            # Modules for CTC head
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_dim, vocab_size),
                nn.LogSoftmax(dim=-1),
            )

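# The simple_am_proj / simple_lm_proj layers above produce the vocab-sized
# logits consumed by the smoothed "simple" RNN-T loss pass. A plain nn.Linear
# sketch with the same default shapes (illustration only; ScaledLinear is, to
# my reading, icefall's Linear variant with scaled initialization):
import torch
import torch.nn as nn

toy_am_proj = nn.Linear(768, 500)  # (N, T, encoder_dim) -> (N, T, vocab_size)
toy_lm_proj = nn.Linear(512, 500)  # (N, U + 1, decoder_dim) -> (N, U + 1, vocab_size)
am_logits = toy_am_proj(torch.randn(2, 10, 768))  # am_logits.shape == (2, 10, 500)
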
    def forward_encoder(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.
        Args:
          x:
            A 2-D tensor of shape (N, T).
          padding_mask:
            An optional bool tensor of shape (N, T); True marks padded
            positions. If None, no position is treated as padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        if padding_mask is None:
            padding_mask = torch.zeros_like(x, dtype=torch.bool)

        encoder_out, padding_mask = self.encoder.extract_features(
            source=x,
            padding_mask=padding_mask,
            mask=self.encoder.training,
        )
        encoder_out_lens = torch.sum(~padding_mask, dim=1)
        assert torch.all(encoder_out_lens > 0), encoder_out_lens

        return encoder_out, encoder_out_lens

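# The padding-mask convention above marks padded positions with True, so the
# valid length of each utterance is the number of False entries, exactly as
# computed by torch.sum(~padding_mask, dim=1). A toy check (illustration only):
import torch

toy_mask = torch.tensor([[False, False, True], [False, True, True]])
toy_lens = torch.sum(~toy_mask, dim=1)  # tensor([2, 1])
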
    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Number of target tokens per utterance, of shape (N,).
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets,
            input_lengths=encoder_out_lens,
            target_lengths=target_lengths,
            reduction="sum",
        )
        return ctc_loss

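# forward_ctc above expects un-padded targets concatenated into one 1-D
# tensor, with per-utterance lengths passed separately. A toy call with random
# log-probs (illustration only; token ids are made up):
import torch

toy_targets = torch.tensor([5, 9, 7])  # utt 1: [5, 9]; utt 2: [7]
toy_target_lengths = torch.tensor([2, 1])
toy_log_probs = torch.randn(10, 2, 20).log_softmax(-1)  # (T, N, C), as after permute
toy_input_lengths = torch.tensor([10, 8])
toy_loss = torch.nn.functional.ctc_loss(
    toy_log_probs, toy_targets, toy_input_lengths, toy_target_lengths, reduction="sum"
)
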
    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          y_lens:
            Number of labels per utterance, of shape (N,).
          prune_range:
            The prune range for rnnt loss; it means how many symbols (context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of the encoder network)
            part.
          lm_scale:
            The scale to smooth the loss with lm (output of the predictor network)
            part.
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # if self.training and random.random() < 0.25:
        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
        # if self.training and random.random() < 0.25:
        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss

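# A sketch of the k2 `boundary` tensor built above (illustration only). Each
# row follows, to my understanding of the k2 convention,
# [begin_symbol, begin_frame, end_symbol, end_frame]; the first two columns
# stay zero here:
import torch

toy_y_lens = torch.tensor([3, 5])
toy_out_lens = torch.tensor([17, 23])
toy_boundary = torch.zeros((2, 4), dtype=torch.int64)
toy_boundary[:, 2] = toy_y_lens
toy_boundary[:, 3] = toy_out_lens
# toy_boundary -> tensor([[ 0,  0,  3, 17],
#                         [ 0,  0,  5, 23]])
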
    def forward(
        self,
        x: torch.Tensor,
        y: k2.RaggedTensor,
        padding_mask: Optional[torch.Tensor] = None,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, T).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          padding_mask:
            An optional bool tensor of shape (N, T); True marks padded positions.
          prune_range:
            The prune range for rnnt loss; it means how many symbols (context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of the encoder network)
            part.
          lm_scale:
            The scale to smooth the loss with lm (output of the predictor network)
            part.
        Returns:
          Return the transducer losses, the CTC loss, and the encoder output
          lengths, in the form of
          (simple_loss, pruned_loss, ctc_loss, encoder_out_lens).

        Note:
          Regarding am_scale & lm_scale, it will make the loss function one of
          the form:
            lm_scale * lm_probs + am_scale * am_probs +
            (1 - lm_scale - am_scale) * combined_probs
        """
        assert x.ndim == 2, x.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == y.dim0, (x.shape, y.dim0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(x.device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)

        if self.use_ctc:
            # Compute CTC loss
            targets = y.values
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss, encoder_out_lens

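# A hedged sketch of how a training script might combine the losses returned
# above; the real weighting lives in the training code, which this diff does
# not show in full, and the scale values below are made up for illustration.
def combine_losses(simple_loss, pruned_loss, ctc_loss, simple_scale=0.5, ctc_scale=0.2):
    # Interpolate the two transducer losses, then add a weighted CTC term.
    return (
        simple_scale * simple_loss
        + (1 - simple_scale) * pruned_loss
        + ctc_scale * ctc_loss
    )
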
1
egs/gigaspeech2/SSL/zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/optim.py
1351
egs/gigaspeech2/SSL/zipformer/pretrain.py
Normal file
File diff suppressed because it is too large
1
egs/gigaspeech2/SSL/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/scaling.py
341
egs/gigaspeech2/SSL/zipformer/ssl_datamodule.py
Normal file
@ -0,0 +1,341 @@
# Copyright      2021  Piotr Żelasko
# Copyright      2023  Xiaomi Corporation     (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from dataset import HubertDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool

class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)

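# _SeedWorkers above gives every dataloader worker a distinct, reproducible
# seed derived from one base value, so per-worker randomness differs while the
# run stays deterministic. A tiny illustration:
base_seed = 42
worker_seeds = [base_seed + worker_id for worker_id in range(4)]  # [42, 43, 44, 45]
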
class LibriSpeechDataModule:
    """
    DataModule for SSL experiments.
    It assumes there is always one train and one valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in SSL
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers.

    This class should be derived for specific corpora used in SSL tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="SSL data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes and sampling strategies.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="When enabled, use the 960h LibriSpeech. "
            "Otherwise, use the 100h subset.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/kmeans"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop the last batch. Used by the sampler.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--do-normalize",
            type=str2bool,
            default=True,
            help="Whether to normalize the audio samples.",
        )
        group.add_argument(
            "--random-crop",
            type=str2bool,
            default=True,
            help="Whether to randomly crop the audio samples.",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
        do_normalize: bool = True,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = HubertDataset(
            max_sample_size=max_sample_size,
            sample_rate=sample_rate,
            label_rate=label_rate,
            random_crop=random_crop,
            pad_audio=pad_audio,
            num_classes=num_classes,
            do_normalize=do_normalize,
        )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(
        self,
        cuts_valid: CutSet,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
        do_normalize: bool = True,
    ) -> DataLoader:
        logging.info("About to create dev dataset")
        validate = HubertDataset(
            max_sample_size=max_sample_size,
            sample_rate=sample_rate,
            label_rate=label_rate,
            random_crop=random_crop,
            pad_audio=pad_audio,
            num_classes=num_classes,
            do_normalize=do_normalize,
        )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(
        self,
        cuts: CutSet,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
        do_normalize: bool = True,
    ) -> DataLoader:
        logging.debug("About to create test dataset")
        test = HubertDataset(
            sample_rate=sample_rate,
            label_rate=label_rate,
            random_crop=random_crop,
            pad_audio=pad_audio,
            num_classes=num_classes,
            do_normalize=do_normalize,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_clean_100_cuts(self) -> CutSet:
        logging.info("About to get train-clean-100 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
        )

    @lru_cache()
    def train_clean_360_cuts(self) -> CutSet:
        logging.info("About to get train-clean-360 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
        )

    @lru_cache()
    def train_other_500_cuts(self) -> CutSet:
        logging.info("About to get train-other-500 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        train_clean_100_cuts = self.train_clean_100_cuts()
        train_clean_360_cuts = self.train_clean_360_cuts()
        train_other_500_cuts = self.train_other_500_cuts()
        return CutSet.mux(
            train_clean_100_cuts,
            train_clean_360_cuts,
            train_other_500_cuts,
            weights=[
                28539,  # len(train_clean_100_cuts)
                104014,  # len(train_clean_360_cuts)
                148688,  # len(train_other_500_cuts)
            ],
        )

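# The mux weights above sample each subset in proportion to its size
# (illustration only); e.g. the chance of drawing a train-other-500 cut is
# roughly:
p_other = 148688 / (28539 + 104014 + 148688)  # ~0.53
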
    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
        )

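# A usage sketch for the datamodule above (illustration only; the argv list is
# inlined just to keep the example self-contained, and loading cuts requires
# the prepared manifests on disk):
import argparse

parser = argparse.ArgumentParser()
LibriSpeechDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "300", "--num-workers", "4"])
datamodule = LibriSpeechDataModule(args)
train_cuts = datamodule.train_all_shuf_cuts()  # lazily loads the manifests
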
1
egs/gigaspeech2/SSL/zipformer/subsampling.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py
1
egs/gigaspeech2/SSL/zipformer/utils.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/utils.py
2438
egs/gigaspeech2/SSL/zipformer/zipformer.py
Normal file
File diff suppressed because it is too large