mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
k2SSL: a Faster and Better Framework for Self-Supervised Speech Representation Learning (#1500)
* Add k2SSL * fix flake8 * fix for black * fix for black * fix for black * Update ssl_datamodule.py * Fix bugs in HubertDataset * update comments * add librilight * add checkpoint convert script * format --------- Co-authored-by: yifanyeung <yifanyeung@yifanyeung.local> Co-authored-by: zzasdf <15218404468@163.com>
This commit is contained in:
parent
c45e9fecfb
commit
87843e9382
1
egs/librilight/SSL/zipformer/asr_datamodule.py
Symbolic link
1
egs/librilight/SSL/zipformer/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/asr_datamodule.py
|
1
egs/librilight/SSL/zipformer/beam_search.py
Symbolic link
1
egs/librilight/SSL/zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/beam_search.py
|
1
egs/librilight/SSL/zipformer/dataset.py
Symbolic link
1
egs/librilight/SSL/zipformer/dataset.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/dataset.py
|
1045
egs/librilight/SSL/zipformer/decode.py
Normal file
1045
egs/librilight/SSL/zipformer/decode.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librilight/SSL/zipformer/decoder.py
Symbolic link
1
egs/librilight/SSL/zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/decoder.py
|
1
egs/librilight/SSL/zipformer/encoder_interface.py
Symbolic link
1
egs/librilight/SSL/zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/encoder_interface.py
|
1552
egs/librilight/SSL/zipformer/finetune.py
Normal file
1552
egs/librilight/SSL/zipformer/finetune.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librilight/SSL/zipformer/hubert_ce.py
Symbolic link
1
egs/librilight/SSL/zipformer/hubert_ce.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/hubert_ce.py
|
1
egs/librilight/SSL/zipformer/joiner.py
Symbolic link
1
egs/librilight/SSL/zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/joiner.py
|
1
egs/librilight/SSL/zipformer/model.py
Symbolic link
1
egs/librilight/SSL/zipformer/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/model.py
|
1
egs/librilight/SSL/zipformer/optim.py
Symbolic link
1
egs/librilight/SSL/zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/optim.py
|
1366
egs/librilight/SSL/zipformer/pretrain.py
Normal file
1366
egs/librilight/SSL/zipformer/pretrain.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librilight/SSL/zipformer/scaling.py
Symbolic link
1
egs/librilight/SSL/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/scaling.py
|
334
egs/librilight/SSL/zipformer/ssl_datamodule.py
Normal file
334
egs/librilight/SSL/zipformer/ssl_datamodule.py
Normal file
@ -0,0 +1,334 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from dataset import HubertDataset
|
||||
from lhotse import CutSet, combine, load_manifest_lazy
|
||||
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriLightDataModule:
|
||||
"""
|
||||
DataModule for SSL experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in SSL
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
|
||||
This class should be derived for specific corpora used in SSL tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR SSL related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/kmeans"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=float,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--do-normalize",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="whether to normalize the data",
|
||||
)
|
||||
group.add_argument(
|
||||
"--random-crop",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="audio sample rate",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = HubertDataset(
|
||||
sample_rate=sample_rate,
|
||||
label_rate=label_rate,
|
||||
random_crop=random_crop,
|
||||
pad_audio=pad_audio,
|
||||
num_classes=num_classes,
|
||||
do_normalize=do_normalize,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(
|
||||
self,
|
||||
cuts_valid: CutSet,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
validate = HubertDataset(
|
||||
sample_rate=sample_rate,
|
||||
label_rate=label_rate,
|
||||
random_crop=random_crop,
|
||||
pad_audio=pad_audio,
|
||||
num_classes=num_classes,
|
||||
do_normalize=do_normalize,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(
|
||||
self,
|
||||
cuts: CutSet,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = HubertDataset(
|
||||
sample_rate=sample_rate,
|
||||
label_rate=label_rate,
|
||||
random_crop=random_crop,
|
||||
pad_audio=pad_audio,
|
||||
num_classes=num_classes,
|
||||
do_normalize=do_normalize,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def small_cuts(self) -> CutSet:
|
||||
logging.info("About to get small cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def medium_cuts(self) -> CutSet:
|
||||
logging.info("About to get medium cuts")
|
||||
filenames = glob.glob(
|
||||
f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz"
|
||||
)
|
||||
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
|
||||
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
||||
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||
sorted_filenames = [f[1] for f in idx_filenames]
|
||||
logging.info(
|
||||
f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode"
|
||||
)
|
||||
|
||||
return combine(load_manifest_lazy(p) for p in sorted_filenames)
|
||||
|
||||
@lru_cache()
|
||||
def large_cuts(self) -> CutSet:
|
||||
logging.info("About to get large cuts")
|
||||
filenames = glob.glob(
|
||||
f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz"
|
||||
)
|
||||
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
|
||||
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
|
||||
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||
sorted_filenames = [f[1] for f in idx_filenames]
|
||||
logging.info(
|
||||
f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode"
|
||||
)
|
||||
|
||||
return combine(load_manifest_lazy(p) for p in sorted_filenames)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info("About to get the shuffled small, medium and large cuts")
|
||||
small_cuts = self.small_cuts()
|
||||
medium_cuts = self.medium_cuts()
|
||||
large_cuts = self.large_cuts()
|
||||
return CutSet.mux(
|
||||
small_cuts,
|
||||
medium_cuts,
|
||||
large_cuts,
|
||||
weights=[
|
||||
122867, # len(small_cuts)
|
||||
1104071, # len(medium_cuts)
|
||||
11012085, # len(large_cuts)
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
1
egs/librilight/SSL/zipformer/utils.py
Symbolic link
1
egs/librilight/SSL/zipformer/utils.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/utils.py
|
1
egs/librilight/SSL/zipformer/wav2vec2_module.py
Symbolic link
1
egs/librilight/SSL/zipformer/wav2vec2_module.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/wav2vec2_module.py
|
1
egs/librilight/SSL/zipformer/zipformer.py
Symbolic link
1
egs/librilight/SSL/zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/SSL/zipformer/zipformer.py
|
287
egs/librispeech/SSL/hubert/asr_datamodule.py
Normal file
287
egs/librispeech/SSL/hubert/asr_datamodule.py
Normal file
@ -0,0 +1,287 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from dataset import HubertAsrDataset
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/wav"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=float,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--do-normalize",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="whether to normalize the data",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
do_normalize: bool,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = HubertAsrDataset(do_normalize=do_normalize)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
validate = HubertAsrDataset(do_normalize=do_normalize)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet, do_normalize: bool) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = HubertAsrDataset(do_normalize=do_normalize)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
train_clean_100_cuts = self.train_clean_100_cuts()
|
||||
train_clean_360_cuts = self.train_clean_360_cuts()
|
||||
train_other_500_cuts = self.train_other_500_cuts()
|
||||
return CutSet.mux(
|
||||
train_clean_100_cuts,
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
weights=[
|
||||
28539, # len(train_clean_100_cuts)
|
||||
104014, # len(train_clean_360_cuts)
|
||||
148688, # len(train_other_500_cuts)
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
840
egs/librispeech/SSL/hubert/attention_module.py
Normal file
840
egs/librispeech/SSL/hubert/attention_module.py
Normal file
@ -0,0 +1,840 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import utils
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import Parameter
|
||||
from utils import FairseqDropout, quant_noise
|
||||
|
||||
_xformers_available = False
|
||||
|
||||
|
||||
# TODO: move this into xformers?
|
||||
# TODO: uint8 input type should just output a bool
|
||||
def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
|
||||
"""
|
||||
call to pytorch multihead accepts three mask types:
|
||||
- ByteTensor where non-zero means to mask
|
||||
- FloatTensor which is an additive mask
|
||||
- BoolTensor where True means to mask
|
||||
xFormers currently accepts boolean and additive maks. For boolean masks
|
||||
the values have opposite meaning. For a BoolTensor True mean to keep the value.
|
||||
"""
|
||||
float_types = [torch.float, torch.float16]
|
||||
# If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool.
|
||||
additive = mask.dtype in float_types
|
||||
# If to_dype is not specified, keep same dtype as mask.
|
||||
to_dtype = mask.dtype if to_dtype is None else to_dtype
|
||||
to_additive = to_dtype in float_types
|
||||
|
||||
if additive:
|
||||
if to_additive:
|
||||
return mask.to(to_dtype)
|
||||
mask = mask < 0
|
||||
|
||||
if to_additive:
|
||||
# return additive mask
|
||||
new_mask = torch.zeros_like(mask, dtype=to_dtype)
|
||||
new_mask = new_mask.masked_fill_(mask, -float("inf"))
|
||||
return new_mask
|
||||
|
||||
# In xFormers True is value to keep rather than value to mask
|
||||
mask = ~mask.to(torch.bool)
|
||||
mask = mask.to(to_dtype)
|
||||
return mask
|
||||
|
||||
|
||||
def init_bert_params(module):
|
||||
"""
|
||||
Initialize the weights specific to the BERT Model.
|
||||
This overrides the default initializations depending on the specified arguments.
|
||||
1. If normal_init_linear_weights is set then weights of linear
|
||||
layer will be initialized using the normal distribution and
|
||||
bais will be set to the specified value.
|
||||
2. If normal_init_embed_weights is set then weights of embedding
|
||||
layer will be initialized using the normal distribution.
|
||||
3. If normal_init_proj_weights is set then weights of
|
||||
in_project_weight for MultiHeadAttention initialized using
|
||||
the normal distribution (to be validated).
|
||||
"""
|
||||
|
||||
def normal_(data):
|
||||
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||
# so that the RNG is consistent with and without FSDP
|
||||
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
normal_(module.weight.data)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
normal_(module.weight.data)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, MultiheadAttention):
|
||||
normal_(module.q_proj.weight.data)
|
||||
normal_(module.k_proj.weight.data)
|
||||
normal_(module.v_proj.weight.data)
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
|
||||
See "Attention Is All You Need" for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
self_attention=False,
|
||||
encoder_decoder_attention=False,
|
||||
dictionary=None,
|
||||
q_noise=0.0,
|
||||
qn_block_size=8,
|
||||
# TODO: pass in config rather than string.
|
||||
# config defined in xformers.components.attention.AttentionConfig
|
||||
xformers_att_config: Optional[str] = None,
|
||||
xformers_blocksparse_layout: Optional[
|
||||
torch.Tensor
|
||||
] = None, # This should be part of the config
|
||||
xformers_blocksparse_blocksize: Optional[
|
||||
int
|
||||
] = 16, # This should be part of the config
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_xformers = False
|
||||
if self.use_xformers and not _xformers_available:
|
||||
raise ImportError("\n\n Please install xFormers.")
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout_module = FairseqDropout(
|
||||
dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.self_attention = self_attention
|
||||
self.encoder_decoder_attention = encoder_decoder_attention
|
||||
|
||||
assert (
|
||||
not self.self_attention or self.qkv_same_dim
|
||||
), "Self-attention requires query, key and value to be of the same size"
|
||||
|
||||
self.k_proj = quant_noise(
|
||||
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
self.v_proj = quant_noise(
|
||||
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
self.q_proj = quant_noise(
|
||||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
|
||||
self.out_proj = quant_noise(
|
||||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
self.beam_size = 1
|
||||
self.reset_parameters()
|
||||
|
||||
self.onnx_trace = False
|
||||
self.skip_embed_dim_check = False
|
||||
|
||||
def prepare_for_onnx_export_(self):
|
||||
self.onnx_trace = True
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.qkv_same_dim:
|
||||
# Empirically observed the convergence to be much better with
|
||||
# the scaled initialization
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||
else:
|
||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if self.out_proj.bias is not None:
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
nn.init.xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
nn.init.xavier_normal_(self.bias_v)
|
||||
|
||||
def _get_reserve_head_index(self, num_heads_to_keep: int):
|
||||
k_proj_heads_norm = []
|
||||
q_proj_heads_norm = []
|
||||
v_proj_heads_norm = []
|
||||
|
||||
for i in range(self.num_heads):
|
||||
start_idx = i * self.head_dim
|
||||
end_idx = (i + 1) * self.head_dim
|
||||
k_proj_heads_norm.append(
|
||||
torch.sum(
|
||||
torch.abs(
|
||||
self.k_proj.weight[
|
||||
start_idx:end_idx,
|
||||
]
|
||||
)
|
||||
).tolist()
|
||||
+ torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
|
||||
)
|
||||
q_proj_heads_norm.append(
|
||||
torch.sum(
|
||||
torch.abs(
|
||||
self.q_proj.weight[
|
||||
start_idx:end_idx,
|
||||
]
|
||||
)
|
||||
).tolist()
|
||||
+ torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
|
||||
)
|
||||
v_proj_heads_norm.append(
|
||||
torch.sum(
|
||||
torch.abs(
|
||||
self.v_proj.weight[
|
||||
start_idx:end_idx,
|
||||
]
|
||||
)
|
||||
).tolist()
|
||||
+ torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
|
||||
)
|
||||
|
||||
heads_norm = []
|
||||
for i in range(self.num_heads):
|
||||
heads_norm.append(
|
||||
k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
|
||||
)
|
||||
|
||||
sorted_head_index = sorted(
|
||||
range(self.num_heads), key=lambda k: heads_norm[k], reverse=True
|
||||
)
|
||||
reserve_head_index = []
|
||||
for i in range(num_heads_to_keep):
|
||||
start = sorted_head_index[i] * self.head_dim
|
||||
end = (sorted_head_index[i] + 1) * self.head_dim
|
||||
reserve_head_index.append((start, end))
|
||||
return reserve_head_index
|
||||
|
||||
def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
|
||||
new_q_weight = []
|
||||
new_q_bias = []
|
||||
new_k_weight = []
|
||||
new_k_bias = []
|
||||
new_v_weight = []
|
||||
new_v_bias = []
|
||||
new_out_proj_weight = []
|
||||
|
||||
for ele in reserve_head_index:
|
||||
start_idx, end_idx = ele
|
||||
new_q_weight.append(
|
||||
self.q_proj.weight[
|
||||
start_idx:end_idx,
|
||||
]
|
||||
)
|
||||
new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
|
||||
|
||||
new_k_weight.append(
|
||||
self.k_proj.weight[
|
||||
start_idx:end_idx,
|
||||
]
|
||||
)
|
||||
|
||||
new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
|
||||
|
||||
new_v_weight.append(
|
||||
self.v_proj.weight[
|
||||
start_idx:end_idx,
|
||||
]
|
||||
)
|
||||
new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
|
||||
|
||||
new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
|
||||
|
||||
new_q_weight = torch.cat(new_q_weight).detach()
|
||||
new_k_weight = torch.cat(new_k_weight).detach()
|
||||
new_v_weight = torch.cat(new_v_weight).detach()
|
||||
new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
|
||||
new_q_weight.requires_grad = True
|
||||
new_k_weight.requires_grad = True
|
||||
new_v_weight.requires_grad = True
|
||||
new_out_proj_weight.requires_grad = True
|
||||
|
||||
new_q_bias = torch.cat(new_q_bias).detach()
|
||||
new_q_bias.requires_grad = True
|
||||
|
||||
new_k_bias = torch.cat(new_k_bias).detach()
|
||||
new_k_bias.requires_grad = True
|
||||
|
||||
new_v_bias = torch.cat(new_v_bias).detach()
|
||||
new_v_bias.requires_grad = True
|
||||
|
||||
self.q_proj.weight = torch.nn.Parameter(new_q_weight)
|
||||
self.q_proj.bias = torch.nn.Parameter(new_q_bias)
|
||||
|
||||
self.k_proj.weight = torch.nn.Parameter(new_k_weight)
|
||||
self.k_proj.bias = torch.nn.Parameter(new_k_bias)
|
||||
|
||||
self.v_proj.weight = torch.nn.Parameter(new_v_weight)
|
||||
self.v_proj.bias = torch.nn.Parameter(new_v_bias)
|
||||
|
||||
self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight)
|
||||
|
||||
self.num_heads = len(reserve_head_index)
|
||||
self.embed_dim = self.head_dim * self.num_heads
|
||||
self.q_proj.out_features = self.embed_dim
|
||||
self.k_proj.out_features = self.embed_dim
|
||||
self.v_proj.out_features = self.embed_dim
|
||||
|
||||
def _set_skip_embed_dim_check(self):
|
||||
self.skip_embed_dim_check = True
|
||||
|
||||
def _pad_masks(
|
||||
self,
|
||||
key_padding_mask: Optional[Tensor],
|
||||
attn_mask: Optional[Tensor],
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
|
||||
if attn_mask is not None:
|
||||
shape = attn_mask.size()[:-1] + torch.Size([1])
|
||||
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
|
||||
if key_padding_mask is not None:
|
||||
shape = key_padding_mask.size()[:-1] + torch.Size([1])
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
key_padding_mask.new_zeros(shape),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return key_padding_mask, attn_mask
|
||||
|
||||
def _add_bias(
|
||||
self,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
key_padding_mask: Optional[Tensor],
|
||||
attn_mask: Optional[Tensor],
|
||||
bsz: int,
|
||||
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
assert self.bias_k is not None
|
||||
assert self.bias_v is not None
|
||||
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||
key_padding_mask, attn_mask = self._pad_masks(
|
||||
key_padding_mask=key_padding_mask, attn_mask=attn_mask
|
||||
)
|
||||
return k, v, key_padding_mask, attn_mask
|
||||
|
||||
def _append_zero_attn(
|
||||
self,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
key_padding_mask: Optional[Tensor],
|
||||
attn_mask: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
|
||||
k = torch.cat(
|
||||
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
|
||||
dim=-2,
|
||||
)
|
||||
v = torch.cat(
|
||||
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
|
||||
dim=-2,
|
||||
)
|
||||
key_padding_mask, attn_mask = self._pad_masks(
|
||||
key_padding_mask=key_padding_mask, attn_mask=attn_mask
|
||||
)
|
||||
return k, v, key_padding_mask, attn_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = True,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
before_softmax: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""Input shape: Time x Batch x Channel
|
||||
|
||||
Args:
|
||||
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||
keys that are pads, of shape `(batch, src_len)`, where
|
||||
padding elements are indicated by 1s.
|
||||
need_weights (bool, optional): return the attention weights,
|
||||
averaged over heads (default: False).
|
||||
attn_mask (ByteTensor, optional): typically used to
|
||||
implement causal attention, where the mask prevents the
|
||||
attention from looking forward in time (default: None).
|
||||
before_softmax (bool, optional): return the raw attention
|
||||
weights and values before the attention softmax.
|
||||
need_head_weights (bool, optional): return the attention
|
||||
weights for each head. Implies *need_weights*. Default:
|
||||
return the average attention weights over all heads.
|
||||
"""
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
is_tpu = query.device.type == "xla"
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
src_len = tgt_len
|
||||
if not self.skip_embed_dim_check:
|
||||
assert (
|
||||
embed_dim == self.embed_dim
|
||||
), f"query dim {embed_dim} != {self.embed_dim}"
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
if key is not None:
|
||||
src_len, key_bsz, _ = key.size()
|
||||
if not torch.jit.is_scripting():
|
||||
assert value is not None
|
||||
assert src_len, key_bsz == value.shape[:2]
|
||||
|
||||
if (
|
||||
not self.onnx_trace
|
||||
and not is_tpu # don't use PyTorch version on TPUs
|
||||
and incremental_state is None
|
||||
and not static_kv
|
||||
# A workaround for quantization to work. Otherwise JIT compilation
|
||||
# treats bias in linear module as method.
|
||||
and not torch.jit.is_scripting()
|
||||
# The Multihead attention implemented in pytorch forces strong dimension check
|
||||
# for input embedding dimention and K,Q,V projection dimension.
|
||||
# Since pruning will break the dimension check and it is not easy to modify the pytorch API,
|
||||
# it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check
|
||||
and not self.skip_embed_dim_check
|
||||
):
|
||||
assert key is not None and value is not None
|
||||
|
||||
return F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
torch.empty([0]),
|
||||
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout_module.p,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
self.training or self.dropout_module.apply_during_inference,
|
||||
key_padding_mask.bool() if key_padding_mask is not None else None,
|
||||
need_weights,
|
||||
attn_mask,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
)
|
||||
|
||||
if incremental_state is not None:
|
||||
saved_state = self._get_input_buffer(incremental_state)
|
||||
if saved_state is not None and "prev_key" in saved_state:
|
||||
# previous time steps are cached - no need to recompute
|
||||
# key and value if they are static
|
||||
if static_kv:
|
||||
assert self.encoder_decoder_attention and not self.self_attention
|
||||
key = value = None
|
||||
else:
|
||||
saved_state = None
|
||||
|
||||
if self.self_attention:
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(query)
|
||||
v = self.v_proj(query)
|
||||
elif self.encoder_decoder_attention:
|
||||
# encoder-decoder attention
|
||||
q = self.q_proj(query)
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = v = None
|
||||
else:
|
||||
if self.beam_size > 1 and bsz == key.size(1):
|
||||
# key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
|
||||
key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
|
||||
:, :, 0, :
|
||||
]
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = key_padding_mask.view(
|
||||
-1, self.beam_size, key_padding_mask.size(1)
|
||||
)[:, 0, :]
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(key)
|
||||
|
||||
else:
|
||||
assert key is not None and value is not None
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
q *= self.scaling
|
||||
|
||||
if self.bias_k is not None:
|
||||
assert self.bias_v is not None
|
||||
k, v, attn_mask, key_padding_mask = self._add_bias(
|
||||
k, v, attn_mask, key_padding_mask, bsz
|
||||
)
|
||||
|
||||
q = (
|
||||
q.contiguous()
|
||||
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
kv_bsz = bsz # need default value for scripting
|
||||
if k is not None:
|
||||
kv_bsz = k.size(1)
|
||||
k = (
|
||||
k.contiguous()
|
||||
.view(-1, kv_bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
if v is not None:
|
||||
v = (
|
||||
v.contiguous()
|
||||
.view(-1, kv_bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
if saved_state is not None:
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
assert _prev_key is not None
|
||||
kv_bsz = _prev_key.size(0)
|
||||
prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
k = prev_key
|
||||
else:
|
||||
assert k is not None
|
||||
k = torch.cat([prev_key, k], dim=1)
|
||||
src_len = k.size(1)
|
||||
if "prev_value" in saved_state:
|
||||
_prev_value = saved_state["prev_value"]
|
||||
assert _prev_value is not None
|
||||
assert kv_bsz == _prev_value.size(0)
|
||||
prev_value = _prev_value.view(
|
||||
kv_bsz * self.num_heads, -1, self.head_dim
|
||||
)
|
||||
if static_kv:
|
||||
v = prev_value
|
||||
else:
|
||||
assert v is not None
|
||||
v = torch.cat([prev_value, v], dim=1)
|
||||
prev_key_padding_mask: Optional[Tensor] = None
|
||||
if "prev_key_padding_mask" in saved_state:
|
||||
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||
assert k is not None and v is not None
|
||||
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||
key_padding_mask=key_padding_mask,
|
||||
prev_key_padding_mask=prev_key_padding_mask,
|
||||
batch_size=kv_bsz,
|
||||
src_len=k.size(1),
|
||||
static_kv=static_kv,
|
||||
)
|
||||
|
||||
saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_value"] = v.view(
|
||||
kv_bsz, self.num_heads, -1, self.head_dim
|
||||
)
|
||||
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||
# In this branch incremental_state is never None
|
||||
assert incremental_state is not None
|
||||
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||
assert k is not None
|
||||
assert k.size(1) == src_len
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism
|
||||
# not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == kv_bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if self.add_zero_attn:
|
||||
assert v is not None
|
||||
src_len += 1
|
||||
k, v, key_padding_mask, attn_mask = self._append_zero_attn(
|
||||
k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
|
||||
)
|
||||
|
||||
if self.encoder_decoder_attention and bsz != kv_bsz:
|
||||
attn_weights = torch.einsum(
|
||||
"bxhtd,bhsd->bxhts",
|
||||
q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
|
||||
k.view((kv_bsz, self.num_heads) + k.size()[1:]),
|
||||
)
|
||||
attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
|
||||
else:
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||
|
||||
assert list(attn_weights.size()) == [
|
||||
bsz * self.num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if self.onnx_trace:
|
||||
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
||||
attn_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
if not is_tpu:
|
||||
attn_weights = attn_weights.view(
|
||||
kv_bsz, -1, self.num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1)
|
||||
.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.to(torch.bool),
|
||||
float("-inf"),
|
||||
)
|
||||
else:
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if before_softmax:
|
||||
return attn_weights, v
|
||||
|
||||
attn_weights_float = utils.softmax(
|
||||
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
||||
)
|
||||
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||
attn_probs = self.dropout_module(attn_weights)
|
||||
|
||||
assert v is not None
|
||||
attn: Optional[Tensor] = None
|
||||
if self.encoder_decoder_attention and bsz != kv_bsz:
|
||||
attn = torch.einsum(
|
||||
"bxhts,bhsd->bxhtd",
|
||||
attn_probs.view(
|
||||
(
|
||||
kv_bsz,
|
||||
-1,
|
||||
self.num_heads,
|
||||
)
|
||||
+ attn_probs.size()[1:]
|
||||
),
|
||||
v.view(
|
||||
(
|
||||
kv_bsz,
|
||||
self.num_heads,
|
||||
)
|
||||
+ v.size()[1:]
|
||||
),
|
||||
)
|
||||
attn = attn.reshape((-1,) + attn.size()[-2:])
|
||||
else:
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
assert list(attn.size()) == [
|
||||
bsz * self.num_heads,
|
||||
tgt_len,
|
||||
self.head_dim,
|
||||
]
|
||||
if self.onnx_trace and attn.size(1) == 1:
|
||||
# when ONNX tracing a single decoder step (sequence length == 1)
|
||||
# the transpose is a no-op copy before view, thus unnecessary
|
||||
attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
|
||||
else:
|
||||
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
attn_weights: Optional[Tensor] = None
|
||||
if need_weights:
|
||||
attn_weights = attn_weights_float.view(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
).transpose(1, 0)
|
||||
if not need_head_weights:
|
||||
# average attention weights over heads
|
||||
attn_weights = attn_weights.mean(dim=0)
|
||||
|
||||
return attn, attn_weights
|
||||
|
||||
@staticmethod
|
||||
def _append_prev_key_padding_mask(
|
||||
key_padding_mask: Optional[Tensor],
|
||||
prev_key_padding_mask: Optional[Tensor],
|
||||
batch_size: int,
|
||||
src_len: int,
|
||||
static_kv: bool,
|
||||
) -> Optional[Tensor]:
|
||||
# saved key padding masks have shape (bsz, seq_len)
|
||||
if prev_key_padding_mask is not None and static_kv:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||
new_key_padding_mask = torch.cat(
|
||||
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
||||
)
|
||||
# During incremental decoding, as the padding token enters and
|
||||
# leaves the frame, there will be a time when prev or current
|
||||
# is None
|
||||
elif prev_key_padding_mask is not None:
|
||||
if src_len > prev_key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||
device=prev_key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat(
|
||||
[prev_key_padding_mask.float(), filler.float()], dim=1
|
||||
)
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask.float()
|
||||
elif key_padding_mask is not None:
|
||||
if src_len > key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - key_padding_mask.size(1)),
|
||||
device=key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat(
|
||||
[filler.float(), key_padding_mask.float()], dim=1
|
||||
)
|
||||
else:
|
||||
new_key_padding_mask = key_padding_mask.float()
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
return new_key_padding_mask
|
||||
|
||||
@torch.jit.export
|
||||
def reorder_incremental_state(
|
||||
self,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
new_order: Tensor,
|
||||
):
|
||||
"""Reorder buffered internal state (for incremental generation)."""
|
||||
input_buffer = self._get_input_buffer(incremental_state)
|
||||
if input_buffer is not None:
|
||||
for k in input_buffer.keys():
|
||||
input_buffer_k = input_buffer[k]
|
||||
if input_buffer_k is not None:
|
||||
if self.encoder_decoder_attention:
|
||||
if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
|
||||
return incremental_state
|
||||
elif self.beam_size > 1:
|
||||
input_buffer[k] = input_buffer_k.index_select(
|
||||
0,
|
||||
new_order.reshape(-1, self.beam_size)[:, 0]
|
||||
// self.beam_size,
|
||||
)
|
||||
else:
|
||||
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
||||
else:
|
||||
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
||||
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
||||
return incremental_state
|
||||
|
||||
def set_beam_size(self, beam_size):
|
||||
"""Used for effiecient beamable enc-dec attention"""
|
||||
self.beam_size = beam_size
|
||||
|
||||
def _get_input_buffer(
|
||||
self,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
) -> Dict[str, Optional[Tensor]]:
|
||||
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||
return empty_result
|
||||
|
||||
def _set_input_buffer(
|
||||
self,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
buffer: Dict[str, Optional[Tensor]],
|
||||
):
|
||||
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||
|
||||
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||
return attn_weights
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
prefix = name + "." if name != "" else ""
|
||||
items_to_add = {}
|
||||
keys_to_remove = []
|
||||
for k in state_dict.keys():
|
||||
if k.endswith(prefix + "in_proj_weight"):
|
||||
# in_proj_weight used to be q + k + v with same dimensions
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
||||
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
||||
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
||||
|
||||
keys_to_remove.append(k)
|
||||
|
||||
k_bias = prefix + "in_proj_bias"
|
||||
if k_bias in state_dict.keys():
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
||||
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
||||
dim : 2 * dim
|
||||
]
|
||||
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
||||
|
||||
keys_to_remove.append(prefix + "in_proj_bias")
|
||||
|
||||
for k in keys_to_remove:
|
||||
del state_dict[k]
|
||||
|
||||
for key, value in items_to_add.items():
|
||||
state_dict[key] = value
|
1
egs/librispeech/SSL/hubert/beam_search.py
Symbolic link
1
egs/librispeech/SSL/hubert/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/beam_search.py
|
367
egs/librispeech/SSL/hubert/dataset.py
Normal file
367
egs/librispeech/SSL/hubert/dataset.py
Normal file
@ -0,0 +1,367 @@
|
||||
# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang)
|
||||
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset.collation import read_audio_from_cuts
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
|
||||
class HubertDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
In this implementation, there will always be a single channel.
|
||||
|
||||
Returns:
|
||||
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
'audio': (B x NumSamples) float tensor
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_sample_size: Optional[int] = None,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.sample_rate = sample_rate
|
||||
self.label_rate = label_rate
|
||||
self.random_crop = random_crop
|
||||
self.pad_audio = pad_audio
|
||||
self.num_classes = num_classes
|
||||
self.normalize = do_normalize
|
||||
self.max_sample_size = (
|
||||
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||
)
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
||||
self._validate(cuts)
|
||||
audio, _ = read_audio_from_cuts(cuts)
|
||||
for i, item in enumerate(audio):
|
||||
audio[i] = self.postprocess(item, self.sample_rate)
|
||||
audio_lens = [cut.num_samples for cut in cuts]
|
||||
|
||||
if self.pad_audio:
|
||||
audio_size = min(max(audio_lens), self.max_sample_size)
|
||||
else:
|
||||
audio_size = min(min(audio_lens), self.max_sample_size)
|
||||
|
||||
audio, padding_mask, audio_starts = self.collater_audio(
|
||||
audio, audio_lens, audio_size
|
||||
)
|
||||
|
||||
kmeans = [cut.custom["kmeans"] for cut in cuts]
|
||||
kmeans = [
|
||||
torch.tensor([int(item) for item in label.split()], dtype=torch.int64)
|
||||
for label in kmeans
|
||||
]
|
||||
kmeans, _ = self.collater_frm_label(kmeans, audio_size, audio_starts)
|
||||
|
||||
return {
|
||||
"cuts": cuts,
|
||||
"audio": audio,
|
||||
"padding_mask": padding_mask,
|
||||
"kmeans": kmeans,
|
||||
}
|
||||
|
||||
def postprocess(self, wav, cur_sample_rate):
|
||||
if wav.dim() == 2:
|
||||
wav = wav.mean(-1)
|
||||
assert wav.dim() == 1, wav.dim()
|
||||
|
||||
if cur_sample_rate != self.sample_rate:
|
||||
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
||||
|
||||
if self.normalize:
|
||||
with torch.no_grad():
|
||||
wav = F.layer_norm(wav, wav.shape)
|
||||
return wav
|
||||
|
||||
def _validate(self, cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
assert all(cut.has_recording for cut in cuts)
|
||||
|
||||
def crop_to_max_size(self, wav, target_size):
|
||||
size = len(wav)
|
||||
diff = size - target_size
|
||||
if diff <= 0:
|
||||
return wav, 0
|
||||
|
||||
start, end = 0, target_size
|
||||
if self.random_crop:
|
||||
start = np.random.randint(0, diff + 1)
|
||||
end = size - diff + start
|
||||
return wav[start:end], start
|
||||
|
||||
def collater_audio(self, audios, audio_lens, audio_size):
|
||||
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
||||
padding_mask = (
|
||||
torch.BoolTensor(collated_audios.shape).fill_(False)
|
||||
# if self.pad_audio else None
|
||||
)
|
||||
audio_starts = [0 for _ in audios]
|
||||
for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
|
||||
audio = audio[:audio_len]
|
||||
diff = audio_len - audio_size
|
||||
if diff == 0:
|
||||
collated_audios[i] = audio
|
||||
elif diff < 0:
|
||||
assert self.pad_audio
|
||||
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
||||
padding_mask[i, diff:] = True
|
||||
else:
|
||||
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
||||
audio, audio_size
|
||||
)
|
||||
return collated_audios, padding_mask, audio_starts
|
||||
|
||||
def collate_tokens(
|
||||
self,
|
||||
values,
|
||||
pad_idx,
|
||||
eos_idx=None,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
pad_to_length=None,
|
||||
pad_to_multiple=1,
|
||||
pad_to_bsz=None,
|
||||
):
|
||||
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
||||
size = max(v.size(0) for v in values)
|
||||
size = size if pad_to_length is None else max(size, pad_to_length)
|
||||
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
|
||||
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
|
||||
|
||||
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
|
||||
res = values[0].new(batch_size, size).fill_(pad_idx)
|
||||
|
||||
def copy_tensor(src, dst):
|
||||
assert dst.numel() == src.numel()
|
||||
if move_eos_to_beginning:
|
||||
if eos_idx is None:
|
||||
# if no eos_idx is specified, then use the last token in src
|
||||
dst[0] = src[-1]
|
||||
else:
|
||||
dst[0] = eos_idx
|
||||
dst[1:] = src[:-1]
|
||||
else:
|
||||
dst.copy_(src)
|
||||
|
||||
for i, v in enumerate(values):
|
||||
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
|
||||
return res
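# A small usage sketch for collate_tokens (hypothetical toy values, not taken
# from the recipe): with left_pad=False the shorter sequences are right-padded
# with pad_idx, e.g.
#
#   values = [torch.tensor([7, 8, 9]), torch.tensor([4, 5])]
#   collated = self.collate_tokens(values, pad_idx=503)
#   # collated == tensor([[  7,   8,   9],
#   #                     [  4,   5, 503]])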
|
||||
|
||||
def collater_frm_label(self, targets, audio_size, audio_starts):
|
||||
label_rate = self.label_rate
|
||||
pad = self.num_classes[0] - 1
|
||||
assert label_rate > 0
|
||||
s2f = label_rate / self.sample_rate
|
||||
frm_starts = [int(round(s * s2f)) for s in audio_starts]
|
||||
frm_size = int(round(audio_size * s2f))
|
||||
if not self.pad_audio:
|
||||
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
|
||||
frm_size = min(frm_size, *rem_size)
|
||||
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
|
||||
|
||||
lengths = torch.LongTensor([len(t) for t in targets])
|
||||
targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
||||
return targets, lengths
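# Worked example for collater_frm_label (hypothetical numbers, assuming the
# class defaults sample_rate=16000 and label_rate=50): s2f = 50 / 16000 =
# 0.003125, so an audio crop starting at sample 32000 maps to label frame
# round(32000 * 0.003125) = 100, and a 160000-sample crop spans
# round(160000 * 0.003125) = 500 label frames.  The pad index is
# num_classes[0] - 1, i.e. 503 for the default 504 classes.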
|
||||
|
||||
|
||||
class HubertAsrDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
In this implementation, there will always be a single channel.
|
||||
|
||||
Returns:
|
||||
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
'cuts': the lhotse CutSet for this batch,
'audio': (B x NumSamples) float tensor,
'padding_mask': (B x NumSamples) bool tensor, True at padded positions,
'supervisions': collated dict containing the supervision 'text' fields
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_sample_size: Optional[int] = None,
|
||||
sample_rate: float = 16000,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = True,
|
||||
do_normalize: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.sample_rate = sample_rate
|
||||
self.random_crop = random_crop
|
||||
self.pad_audio = pad_audio
|
||||
self.normalize = do_normalize
|
||||
self.max_sample_size = (
|
||||
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||
)
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
||||
self._validate(cuts)
|
||||
audio, _ = read_audio_from_cuts(cuts)
|
||||
for i, item in enumerate(audio):
|
||||
audio[i] = self.postprocess(item, self.sample_rate)
|
||||
audio_lens = [cut.num_samples for cut in cuts]
|
||||
if self.pad_audio:
|
||||
audio_size = min(max(audio_lens), self.max_sample_size)
|
||||
else:
|
||||
audio_size = min(min(audio_lens), self.max_sample_size)
|
||||
|
||||
audio, padding_mask, audio_starts = self.collater_audio(
|
||||
audio, audio_lens, audio_size
|
||||
)
|
||||
|
||||
return {
|
||||
"cuts": cuts,
|
||||
"audio": audio,
|
||||
"padding_mask": padding_mask,
|
||||
"supervisions": default_collate(
|
||||
[
|
||||
{
|
||||
"text": supervision.text,
|
||||
}
|
||||
for sequence_idx, cut in enumerate(cuts)
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
def postprocess(self, wav, cur_sample_rate):
|
||||
if wav.dim() == 2:
|
||||
wav = wav.mean(-1)
|
||||
assert wav.dim() == 1, wav.dim()
|
||||
|
||||
if cur_sample_rate != self.sample_rate:
|
||||
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
||||
|
||||
if self.normalize:
|
||||
with torch.no_grad():
|
||||
wav = F.layer_norm(wav, wav.shape)
|
||||
return wav
|
||||
|
||||
def _validate(self, cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
assert all(cut.has_recording for cut in cuts)
|
||||
|
||||
def crop_to_max_size(self, wav, target_size):
|
||||
size = len(wav)
|
||||
diff = size - target_size
|
||||
if diff <= 0:
|
||||
return wav, 0
|
||||
|
||||
start, end = 0, target_size
|
||||
if self.random_crop:
|
||||
start = np.random.randint(0, diff + 1)
|
||||
end = size - diff + start
|
||||
return wav[start:end], start
|
||||
|
||||
def collater_audio(self, audios, audio_lens, audio_size):
|
||||
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
||||
padding_mask = (
|
||||
torch.BoolTensor(collated_audios.shape).fill_(False)
|
||||
# if self.pad_audio else None
|
||||
)
|
||||
audio_starts = [0 for _ in audios]
|
||||
for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
|
||||
audio = audio[:audio_len]
|
||||
diff = audio_len - audio_size
|
||||
if diff == 0:
|
||||
collated_audios[i] = audio
|
||||
elif diff < 0:
|
||||
assert self.pad_audio
|
||||
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
||||
padding_mask[i, diff:] = True
|
||||
else:
|
||||
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
||||
audio, audio_size
|
||||
)
|
||||
return collated_audios, padding_mask, audio_starts
|
||||
|
||||
def collate_tokens(
|
||||
self,
|
||||
values,
|
||||
pad_idx,
|
||||
eos_idx=None,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
pad_to_length=None,
|
||||
pad_to_multiple=1,
|
||||
pad_to_bsz=None,
|
||||
):
|
||||
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
||||
size = max(v.size(0) for v in values)
|
||||
size = size if pad_to_length is None else max(size, pad_to_length)
|
||||
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
|
||||
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
|
||||
|
||||
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
|
||||
res = values[0].new(batch_size, size).fill_(pad_idx)
|
||||
|
||||
def copy_tensor(src, dst):
|
||||
assert dst.numel() == src.numel()
|
||||
if move_eos_to_beginning:
|
||||
if eos_idx is None:
|
||||
# if no eos_idx is specified, then use the last token in src
|
||||
dst[0] = src[-1]
|
||||
else:
|
||||
dst[0] = eos_idx
|
||||
dst[1:] = src[:-1]
|
||||
else:
|
||||
dst.copy_(src)
|
||||
|
||||
for i, v in enumerate(values):
|
||||
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lhotse import load_manifest_lazy
|
||||
from lhotse.dataset import DynamicBucketingSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
dataset = HubertDataset()
|
||||
cuts = load_manifest_lazy("data/fbank2/librispeech_cuts_train-clean-100.jsonl.gz")
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=100,
|
||||
shuffle=False,
|
||||
)
|
||||
dl = DataLoader(
|
||||
dataset,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
print(batch)
|
||||
break
|
1045
egs/librispeech/SSL/hubert/decode.py
Normal file
1045
egs/librispeech/SSL/hubert/decode.py
Normal file
File diff suppressed because it is too large
Load Diff
1045
egs/librispeech/SSL/hubert/decode_ce.py
Normal file
1045
egs/librispeech/SSL/hubert/decode_ce.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/hubert/decoder.py
Symbolic link
1
egs/librispeech/SSL/hubert/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/decoder.py
|
1254
egs/librispeech/SSL/hubert/finetune.py
Normal file
1254
egs/librispeech/SSL/hubert/finetune.py
Normal file
File diff suppressed because it is too large
Load Diff
1254
egs/librispeech/SSL/hubert/finetune_ce.py
Normal file
1254
egs/librispeech/SSL/hubert/finetune_ce.py
Normal file
File diff suppressed because it is too large
Load Diff
984
egs/librispeech/SSL/hubert/hubert.py
Normal file
984
egs/librispeech/SSL/hubert/hubert.py
Normal file
@ -0,0 +1,984 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from utils import GradMultiply, LayerNorm
|
||||
from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
|
||||
|
||||
|
||||
def compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
mask_type: str = "static",
|
||||
mask_other: float = 0.0,
|
||||
min_masks: int = 0,
|
||||
no_overlap: bool = False,
|
||||
min_space: int = 0,
|
||||
require_same_masks: bool = True,
|
||||
mask_dropout: float = 0.0,
|
||||
add_masks: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
epoch: Optional[int] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
|
||||
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_type: how to compute mask lengths
|
||||
static = fixed size
|
||||
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
||||
poisson = sample from a Poisson distribution with lambda = mask_length
|
||||
min_masks: minimum number of masked spans
|
||||
no_overlap: if true, use an alternative recursive algorithm that prevents spans from overlapping
|
||||
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||
require_same_masks: if true, will randomly drop out masks until the same number of masked positions remains in each sample
|
||||
mask_dropout: randomly drop out this fraction of the masked positions in each example
|
||||
"""
|
||||
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
if num_mask_ver == 1:
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
for i in range(bsz):
|
||||
if seed is not None and epoch is not None and indices is not None:
|
||||
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
||||
else:
|
||||
seed_i = None
|
||||
|
||||
rng = np.random.default_rng(seed_i)
|
||||
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
assert sz >= 0, sz
|
||||
else:
|
||||
sz = all_sz
|
||||
|
||||
if num_mask_ver == 1:
|
||||
if padding_mask is not None:
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
num_mask = all_num_mask
|
||||
elif num_mask_ver == 2:
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ rng.random()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
if mask_type == "static":
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
elif mask_type == "uniform":
|
||||
lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)  # np.random.Generator provides integers(), not randint()
|
||||
elif mask_type == "normal":
|
||||
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
||||
lengths = [max(1, int(round(x))) for x in lengths]
|
||||
elif mask_type == "poisson":
|
||||
lengths = rng.poisson(mask_length, size=num_mask)
|
||||
lengths = [int(round(x)) for x in lengths]
|
||||
else:
|
||||
raise Exception("unknown mask selection " + mask_type)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
if mask_type == "static":
|
||||
raise ValueError("this should never happen")
|
||||
else:
|
||||
lengths = [min(mask_length, sz - 1)]
|
||||
|
||||
if no_overlap:
|
||||
mask_idc = []
|
||||
|
||||
def arrange(s, e, length, keep_length):
|
||||
span_start = rng.integers(s, e - length)  # np.random.Generator provides integers(), not randint()
|
||||
mask_idc.extend(span_start + i for i in range(length))
|
||||
|
||||
new_parts = []
|
||||
if span_start - s - min_space >= keep_length:
|
||||
new_parts.append((s, span_start - min_space + 1))
|
||||
if e - span_start - length - min_space > keep_length:
|
||||
new_parts.append((span_start + length + min_space, e))
|
||||
return new_parts
|
||||
|
||||
parts = [(0, sz)]
|
||||
min_length = min(lengths)
|
||||
for length in sorted(lengths, reverse=True):
|
||||
lens = np.fromiter(
|
||||
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||
np.int64,  # np.int was removed in recent NumPy releases
|
||||
)
|
||||
l_sum = np.sum(lens)
|
||||
if l_sum == 0:
|
||||
break
|
||||
probs = lens / np.sum(lens)
|
||||
c = rng.choice(len(parts), p=probs)
|
||||
s, e = parts.pop(c)
|
||||
parts.extend(arrange(s, e, length, min_length))
|
||||
mask_idc = np.asarray(mask_idc)
|
||||
else:
|
||||
if idc_select_ver == 1:
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
||||
elif idc_select_ver == 2:
|
||||
mask_idc = rng.choice(sz, num_mask, replace=False)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
mask_idc = np.asarray(
|
||||
[
|
||||
mask_idc[j] + offset
|
||||
for j in range(len(mask_idc))
|
||||
for offset in range(lengths[j])
|
||||
]
|
||||
)
|
||||
|
||||
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
||||
if len(mask_idc) >= sz:
|
||||
raise ValueError(
|
||||
(
|
||||
f"the entire sequence is masked. "
|
||||
f"sz={sz}; mask_idc={mask_idc}; "
|
||||
f"index={indices[i] if indices is not None else None}"
|
||||
)
|
||||
)
|
||||
mask_idcs.append(mask_idc)
|
||||
|
||||
target_len = None
|
||||
if require_same_masks:
|
||||
if add_masks:
|
||||
target_len = max([len(m) for m in mask_idcs])
|
||||
else:
|
||||
target_len = min([len(m) for m in mask_idcs])
|
||||
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if target_len is not None and len(mask_idc) > target_len:
|
||||
mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
||||
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
if target_len is not None and len(mask_idc) < target_len:
|
||||
unmasked = np.flatnonzero(~mask[i])
|
||||
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
|
||||
mask[i, to_mask] = True
|
||||
|
||||
if mask_dropout > 0:
|
||||
masked = np.flatnonzero(mask[i])
|
||||
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
|
||||
to_drop = rng.choice(masked, num_holes, replace=False)
|
||||
mask[i, to_drop] = False
|
||||
|
||||
return mask
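# Illustrative call of compute_mask_indices (a hedged sketch with toy numbers,
# not taken from the training scripts):
#
#   toy_mask = compute_mask_indices(
#       shape=(2, 100),     # batch of 2 utterances, 100 frames each
#       padding_mask=None,
#       mask_prob=0.65,
#       mask_length=10,
#       mask_type="static",
#       min_masks=2,
#   )
#
# toy_mask is a (2, 100) boolean np.ndarray whose True entries mark
# (possibly overlapping) length-10 spans; because require_same_masks defaults
# to True, both rows end up with the same number of masked frames.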
|
||||
|
||||
|
||||
def add_hubert_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--label-rate",
|
||||
type=float,
|
||||
default=50,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=float,
|
||||
default=16000,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--extractor-mode",
|
||||
type=str,
|
||||
default="default",
|
||||
help="""mode for feature extractor, should be in EXTRACTOR_MODE_CHOICES. default has a single group
|
||||
norm with d groups in the first conv block, whereas layer_norm
|
||||
has layer norms in every block (meant to use with normalize=True)""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-layers",
|
||||
type=int,
|
||||
default=12,
|
||||
help="num encoder layers in the transformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-embed-dim",
|
||||
type=int,
|
||||
default=768,
|
||||
help="encoder embedding dimension",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-ffn-embed-dim",
|
||||
type=int,
|
||||
default=3072,
|
||||
help="encoder embedding dimension for FFN",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-attention-heads",
|
||||
type=int,
|
||||
default=12,
|
||||
help="num encoder attention heads",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--activation-fn",
|
||||
type=str,
|
||||
choices=[
|
||||
"relu",
|
||||
"gelu",
|
||||
"gelu_fast",
|
||||
"gelu_accurate",
|
||||
"tanh",
|
||||
"linear",
|
||||
],
|
||||
default="gelu",
|
||||
help="activation function to use",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--layer-type",
|
||||
type=str,
|
||||
choices=["transformer", "conformer", "trf_adp"],
|
||||
default="transformer",
|
||||
help="layer type in encoder",
|
||||
)
|
||||
|
||||
# dropouts
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="dropout probability for the transformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-dropout",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="dropout probability for attention weights",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--activation-dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="dropout probability after activation in FFN",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-layerdrop",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="probability of dropping a transformer layer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dropout-input",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="dropout to apply to the input (after feat extr)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dropout-features",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="dropout to apply to the features (after feat extr)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--final-dim",
|
||||
type=int,
|
||||
default=0,
|
||||
help="project final representations and targets to this many dimensions. set to encoder_embed_dim if <= 0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--untie-final-proj",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use separate projection for each target",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--layer-norm-first",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="apply layernorm first in the transformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-feature-layers",
|
||||
type=str,
|
||||
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
|
||||
help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
|
||||
)
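# Arithmetic behind the default above (a sketch, not new behaviour): the
# strides are 5, 2, 2, 2, 2, 2, 2, so the overall downsampling factor is
# 5 * 2**6 = 320.  At --sample-rate 16000 this yields 16000 / 320 = 50
# feature frames per second, matching the --label-rate default of 50.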
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-bias",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="include bias in conv encoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--logit-temp",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="temperature to divide logits by",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-glu",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="adds projection + glu to targets",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature-grad-mult",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="multiply feature extractor var grads by this",
|
||||
)
|
||||
|
||||
# masking
|
||||
parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-prob",
|
||||
type=float,
|
||||
default=0.65,
|
||||
help="probability of replacing a token with mask",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-selection",
|
||||
type=str,
|
||||
choices=["static", "uniform", "normal", "poisson"],
|
||||
default="static",
|
||||
help="how to choose mask length",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-other",
|
||||
type=float,
|
||||
default=0,
|
||||
help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-mask-overlap",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="if true, prevent mask spans from overlapping",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-min-space",
|
||||
type=int,
|
||||
default=1,
|
||||
help="min space between spans (if no overlap is enabled)",
|
||||
)
|
||||
|
||||
# channel masking
|
||||
parser.add_argument(
|
||||
"--mask-channel-length",
|
||||
type=int,
|
||||
default=10,
|
||||
help="length of the mask for features (channels)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-prob",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="probability of replacing a feature with 0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-selection",
|
||||
type=str,
|
||||
choices=["static", "uniform", "normal", "poisson"],
|
||||
default="static",
|
||||
help="how to choose mask length for channel masking",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-other",
|
||||
type=float,
|
||||
default=0,
|
||||
help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-mask-channel-overlap",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="if true, prevent channel mask spans from overlapping",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-min-space",
|
||||
type=int,
|
||||
default=1,
|
||||
help="min space between spans (if no overlap is enabled)",
|
||||
)
|
||||
|
||||
# positional embeddings
|
||||
parser.add_argument(
|
||||
"--conv-pos",
|
||||
type=int,
|
||||
default=128,
|
||||
help="number of filters for convolutional positional embeddings",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-pos-groups",
|
||||
type=int,
|
||||
default=16,
|
||||
help="number of groups for convolutional positional embedding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-pos-batch-norm",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--latent-temp",
|
||||
type=float,
|
||||
nargs="*",
|
||||
default=[2, 0.5, 0.999995],
|
||||
help="legacy (to be removed)",
|
||||
)
|
||||
|
||||
# loss computation
|
||||
parser.add_argument(
|
||||
"--skip-masked",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="skip computing losses over masked frames",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-nomask",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="skip computing losses over unmasked frames",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-activations",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="recompute activations and save memory for extra compute",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pred-masked-weight",
|
||||
type=float,
|
||||
default=1,
|
||||
help="weight for masked part in ssl loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pred-nomask-weight",
|
||||
type=float,
|
||||
default=0,
|
||||
help="weight for unmasked part in ssl loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--loss-weights",
|
||||
type=float,
|
||||
nargs="*",
|
||||
default=[10],
|
||||
help="weights for auxiliary losses, e.g. the feature penalty",
|
||||
)
|
||||
|
||||
# FP16 optimization
|
||||
parser.add_argument(
|
||||
"--required-seq-len-multiple",
|
||||
type=int,
|
||||
default=2,
|
||||
help="pad the input to encoder such that the sequence length is divisible by multiple",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pos-enc-type",
|
||||
type=str,
|
||||
default="abs",
|
||||
help="Positional encoding type to use in conformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[504],
|
||||
help="""number of classes, slightly larger than the number of clusters,
|
||||
the largest index is reserved for padding,
|
||||
and the value should be a multiple of 4 for faster computation""",
|
||||
)
|
||||
|
||||
|
||||
class HubertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
|
||||
self.embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=cfg.extractor_mode,
|
||||
conv_bias=cfg.conv_bias,
|
||||
)
|
||||
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
||||
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
|
||||
|
||||
self.post_extract_proj = (
|
||||
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||
if self.embed != cfg.encoder_embed_dim
|
||||
else None
|
||||
)
|
||||
|
||||
self.mask_prob = cfg.mask_prob
|
||||
self.mask_selection = cfg.mask_selection
|
||||
self.mask_other = cfg.mask_other
|
||||
self.mask_length = cfg.mask_length
|
||||
self.no_mask_overlap = cfg.no_mask_overlap
|
||||
self.mask_min_space = cfg.mask_min_space
|
||||
|
||||
self.mask_channel_prob = cfg.mask_channel_prob
|
||||
self.mask_channel_selection = cfg.mask_channel_selection
|
||||
self.mask_channel_other = cfg.mask_channel_other
|
||||
self.mask_channel_length = cfg.mask_channel_length
|
||||
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||
|
||||
self.feature_grad_mult = cfg.feature_grad_mult
|
||||
self.logit_temp = cfg.logit_temp
|
||||
self.skip_masked = cfg.skip_masked
|
||||
self.skip_nomask = cfg.skip_nomask
|
||||
|
||||
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
|
||||
|
||||
self.mask_emb = nn.Parameter(
|
||||
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
||||
)
|
||||
|
||||
self.encoder = TransformerEncoder(cfg)
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
self.target_glu = None
|
||||
if cfg.target_glu:
|
||||
self.target_glu = nn.Sequential(
|
||||
nn.Linear(final_dim, final_dim * 2), nn.GLU()
|
||||
)
|
||||
|
||||
self.untie_final_proj = cfg.untie_final_proj
|
||||
if self.untie_final_proj:
|
||||
self.final_proj = nn.Linear(
|
||||
cfg.encoder_embed_dim, final_dim * len(cfg.num_classes)
|
||||
)
|
||||
else:
|
||||
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
|
||||
|
||||
# modules below are not needed during fine-tuning
|
||||
self.num_classes = cfg.num_classes
|
||||
self.label_embs_concat = nn.Parameter(
|
||||
torch.FloatTensor(sum(self.num_classes), final_dim)
|
||||
)
|
||||
self.pred_masked_weight = cfg.pred_masked_weight
|
||||
self.pred_nomask_weight = cfg.pred_nomask_weight
|
||||
self.loss_weights = cfg.loss_weights
|
||||
nn.init.uniform_(self.label_embs_concat)
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
||||
|
||||
# nn.Module does not define upgrade_state_dict_named, so only delegate
# if a parent class actually provides it.
if hasattr(super(), "upgrade_state_dict_named"):
    super().upgrade_state_dict_named(state_dict, name)
|
||||
return state_dict
|
||||
|
||||
def apply_mask(self, x, padding_mask, target_list):
|
||||
B, T, C = x.shape
|
||||
if self.mask_prob > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=2,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space,
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||
x[mask_indices] = self.mask_emb.to(x.dtype)
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = (
|
||||
torch.from_numpy(mask_channel_indices)
|
||||
.to(x.device)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, T, -1)
|
||||
)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
|
||||
|
||||
def compute_nce(self, x, pos, negs):
|
||||
neg_is_pos = (pos == negs).all(-1)
|
||||
pos = pos.unsqueeze(0)
|
||||
targets = torch.cat([pos, negs], dim=0)
|
||||
|
||||
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
|
||||
logits /= self.logit_temp
|
||||
if neg_is_pos.any():
|
||||
logits[1:][neg_is_pos] = float("-inf")
|
||||
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
|
||||
return logits
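# Shape sketch for compute_nce (following the comments in compute_pred below,
# stated here as an assumption rather than new behaviour): x and pos are
# (S, D) and negs is (Neg, S, D); after the unsqueeze/cat the cosine
# similarities have shape (Neg + 1, S), and the transpose returns (S, Neg + 1)
# logits in which column 0 is the positive, hence the all-zero targets used
# for cross-entropy in compute_loss.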
|
||||
|
||||
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(source)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(source)
|
||||
return features
|
||||
|
||||
def forward_targets(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
target_list: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Trim features to ensure labels exist and then get aligned labels
|
||||
feat_tsz = features.size(2)
|
||||
targ_tsz = min([t.size(1) for t in target_list])
|
||||
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
||||
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
||||
features = features[..., :feat_tsz]
|
||||
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
|
||||
target_list = [t[:, target_inds.long()] for t in target_list]
|
||||
return features, target_list
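# Worked example (hypothetical numbers): the default conv extractor
# downsamples by 320, so feat2tar_ratio = label_rate * 320 / sample_rate =
# 50 * 320 / 16000 = 1.0 and target_inds selects one label per feature frame;
# a label_rate of 100 would instead select every second label.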
|
||||
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
extra = padding_mask.size(1) % features.size(1)
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
||||
padding_mask = padding_mask.all(-1)
|
||||
return padding_mask
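# Example of the padding-mask downsampling above (toy numbers): for a
# (B, 320000) sample-level padding mask and (B, 1000, D) features, extra = 0
# and the view gives shape (B, 1000, 320); a feature frame is flagged as
# padding only if all 320 of its samples are padding.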
|
||||
|
||||
def forward(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
target_list: Optional[List[torch.Tensor]] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = True,
|
||||
features_only: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
):
|
||||
"""output layer is 1-based"""
|
||||
features = self.forward_features(source)
|
||||
if target_list is not None:
|
||||
features, target_list = self.forward_targets(features, target_list)
|
||||
|
||||
features_pen = features.float().pow(2).mean()
|
||||
|
||||
features = features.transpose(1, 2)
|
||||
features = self.layer_norm(features)
|
||||
unmasked_features = features.clone()
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
features = self.dropout_input(features)
|
||||
unmasked_features = self.dropout_features(unmasked_features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
||||
else:
|
||||
x = features
|
||||
mask_indices = None
|
||||
|
||||
# feature: (B, T, D), float
|
||||
# target: (B, T), long
|
||||
# x: (B, T, D), float
|
||||
# padding_mask: (B, T), bool
|
||||
# mask_indices: (B, T), bool
|
||||
x, _ = self.encoder(
|
||||
x,
|
||||
padding_mask=padding_mask,
|
||||
layer=None if output_layer is None else output_layer - 1,
|
||||
)
|
||||
|
||||
if features_only:
|
||||
return {"x": x, "padding_mask": padding_mask, "features": features}
|
||||
|
||||
def compute_pred(proj_x, target, label_embs):
|
||||
# compute logits for the i-th label set
|
||||
y = torch.index_select(label_embs, 0, target.long())
|
||||
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
||||
if self.target_glu:
|
||||
y = self.target_glu(y)
|
||||
negs = self.target_glu(negs)
|
||||
# proj_x: (S, D)
|
||||
# y: (S, D)
|
||||
# negs: (Neg, S, D)
|
||||
return self.compute_nce(proj_x, y, negs)
|
||||
|
||||
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
||||
|
||||
if not self.skip_masked:
|
||||
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
||||
proj_x_m = self.final_proj(x[masked_indices])
|
||||
if self.untie_final_proj:
|
||||
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
|
||||
else:
|
||||
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
|
||||
logit_m_list = [
|
||||
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
|
||||
for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
|
||||
]
|
||||
else:
|
||||
logit_m_list = [None for _ in target_list]
|
||||
|
||||
if not self.skip_nomask:
|
||||
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
||||
proj_x_u = self.final_proj(x[nomask_indices])
|
||||
if self.untie_final_proj:
|
||||
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
|
||||
else:
|
||||
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
|
||||
|
||||
logit_u_list = [
|
||||
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
|
||||
for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
|
||||
]
|
||||
else:
|
||||
logit_u_list = [None for _ in target_list]
|
||||
|
||||
# result = {
|
||||
# "logit_m_list": logit_m_list,
|
||||
# "logit_u_list": logit_u_list,
|
||||
# "padding_mask": padding_mask,
|
||||
# "features_pen": features_pen,
|
||||
# }
|
||||
return self.compute_loss(logit_m_list, logit_u_list, features_pen)
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = False,
|
||||
ret_conv: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
res = self.forward(
|
||||
source,
|
||||
padding_mask=padding_mask,
|
||||
mask=mask,
|
||||
features_only=True,
|
||||
output_layer=output_layer,
|
||||
)
|
||||
feature = res["features"] if ret_conv else res["x"]
|
||||
return feature, res["padding_mask"]
|
||||
|
||||
def get_logits(self, net_output, is_masked=True):
|
||||
if is_masked:
|
||||
logits_list = net_output["logit_m_list"]
|
||||
else:
|
||||
logits_list = net_output["logit_u_list"]
|
||||
logits_list = [x.float() for x in logits_list if x is not None]
|
||||
return logits_list
|
||||
|
||||
def get_targets(self, net_output, is_masked=True):
|
||||
logits_list = self.get_logits(net_output, is_masked)
|
||||
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
|
||||
return targets_list
|
||||
|
||||
def get_extra_losses(self, net_output):
|
||||
extra_losses = []
|
||||
names = []
|
||||
|
||||
if "features_pen" in net_output:
|
||||
extra_losses.append(net_output["features_pen"])
|
||||
names.append("features_pen")
|
||||
|
||||
return extra_losses, names
|
||||
|
||||
def remove_pretraining_modules(self):
|
||||
self.target_glu = None
|
||||
self.final_proj = None
|
||||
|
||||
def compute_loss(self, logit_m_list, logit_u_list, features_pen):
|
||||
loss = 0.0
|
||||
sample_size = 0
|
||||
logging_output = {}
|
||||
reduce = True
|
||||
reduction = "sum" if reduce else "none"
|
||||
|
||||
loss_m_list = []
|
||||
logp_m_list = [x.float() for x in logit_m_list if x is not None]
|
||||
targ_m_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_m_list]
|
||||
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
|
||||
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
|
||||
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
|
||||
loss_m_list.append(loss_m)
|
||||
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
|
||||
if self.pred_masked_weight > 0:
|
||||
loss += self.pred_masked_weight * sum(loss_m_list)
|
||||
sample_size += targ_m_list[0].numel()
|
||||
|
||||
loss_u_list = []
|
||||
logp_u_list = [x.float() for x in logit_u_list if x is not None]
|
||||
targ_u_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_u_list]
|
||||
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
|
||||
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
|
||||
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
|
||||
loss_u_list.append(loss_u)
|
||||
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
|
||||
if self.pred_nomask_weight > 0:
|
||||
loss += self.pred_nomask_weight * sum(loss_u_list)
|
||||
sample_size += targ_u_list[0].numel()
|
||||
|
||||
if self.loss_weights is not None:
|
||||
extra_losses = []
|
||||
names = []
|
||||
extra_losses.append(features_pen)
|
||||
names.append("features_pen")
|
||||
if torch.is_tensor(extra_losses):
|
||||
extra_losses = [extra_losses]
|
||||
names = [names]
|
||||
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
||||
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
||||
assert len(extra_losses) == len(
|
||||
self.loss_weights
|
||||
), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
||||
for p, n, coef in zip(extra_losses, names, self.loss_weights):
|
||||
if coef != 0 and p is not None:
|
||||
p = coef * p.float() * sample_size
|
||||
loss += p
|
||||
logging_output[f"loss_{n}"] = p.item()
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.item() if reduce else loss,
|
||||
**logging_output,
|
||||
}
|
||||
|
||||
# for lk in self.log_keys:
|
||||
# if lk in net_output:
|
||||
# logging_output[lk] = float((net_output[lk]))
|
||||
|
||||
def compute_correct(logits):
|
||||
if logits.numel() == 0:
|
||||
return 0, 0
|
||||
else:
|
||||
assert logits.dim() > 1, logits.shape
|
||||
max = logits.argmax(-1) == 0
|
||||
min = logits.argmin(-1) == 0
|
||||
both = max & min
|
||||
corr = max.long().sum().item() - both.long().sum().item()
|
||||
count = max.numel()
|
||||
return corr, count
|
||||
|
||||
with torch.no_grad():
|
||||
for i, logp_m in enumerate(logp_m_list):
|
||||
corr_m, count_m = compute_correct(logp_m)
|
||||
logging_output[f"correct_m_{i}"] = corr_m
|
||||
logging_output[f"count_m_{i}"] = count_m
|
||||
|
||||
for i, logp_u in enumerate(logp_u_list):
|
||||
corr_u, count_u = compute_correct(logp_u)
|
||||
logging_output[f"correct_u_{i}"] = corr_u
|
||||
logging_output[f"count_u_{i}"] = count_u
|
||||
|
||||
return loss, sample_size, logging_output
|
940
egs/librispeech/SSL/hubert/hubert_ce.py
Normal file
940
egs/librispeech/SSL/hubert/hubert_ce.py
Normal file
@ -0,0 +1,940 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from utils import GradMultiply, LayerNorm
|
||||
from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
|
||||
|
||||
|
||||
def compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
mask_type: str = "static",
|
||||
mask_other: float = 0.0,
|
||||
min_masks: int = 0,
|
||||
no_overlap: bool = False,
|
||||
min_space: int = 0,
|
||||
require_same_masks: bool = True,
|
||||
mask_dropout: float = 0.0,
|
||||
add_masks: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
epoch: Optional[int] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
|
||||
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_type: how to compute mask lengths
|
||||
static = fixed size
|
||||
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
||||
poisson = sample from a Poisson distribution with lambda = mask_length
|
||||
min_masks: minimum number of masked spans
|
||||
no_overlap: if true, use an alternative recursive algorithm that prevents spans from overlapping
|
||||
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||
require_same_masks: if true, will randomly drop out masks until the same number of masked positions remains in each sample
|
||||
mask_dropout: randomly drop out this fraction of the masked positions in each example
|
||||
"""
|
||||
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
if num_mask_ver == 1:
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
for i in range(bsz):
|
||||
if seed is not None and epoch is not None and indices is not None:
|
||||
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
||||
else:
|
||||
seed_i = None
|
||||
|
||||
rng = np.random.default_rng(seed_i)
|
||||
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
assert sz >= 0, sz
|
||||
else:
|
||||
sz = all_sz
|
||||
|
||||
if num_mask_ver == 1:
|
||||
if padding_mask is not None:
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
num_mask = all_num_mask
|
||||
elif num_mask_ver == 2:
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ rng.random()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
if mask_type == "static":
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
elif mask_type == "uniform":
|
||||
lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)  # np.random.Generator provides integers(), not randint()
|
||||
elif mask_type == "normal":
|
||||
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
||||
lengths = [max(1, int(round(x))) for x in lengths]
|
||||
elif mask_type == "poisson":
|
||||
lengths = rng.poisson(mask_length, size=num_mask)
|
||||
lengths = [int(round(x)) for x in lengths]
|
||||
else:
|
||||
raise Exception("unknown mask selection " + mask_type)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
if mask_type == "static":
|
||||
raise ValueError("this should never happen")
|
||||
else:
|
||||
lengths = [min(mask_length, sz - 1)]
|
||||
|
||||
if no_overlap:
|
||||
mask_idc = []
|
||||
|
||||
def arrange(s, e, length, keep_length):
|
||||
span_start = rng.integers(s, e - length)  # np.random.Generator provides integers(), not randint()
|
||||
mask_idc.extend(span_start + i for i in range(length))
|
||||
|
||||
new_parts = []
|
||||
if span_start - s - min_space >= keep_length:
|
||||
new_parts.append((s, span_start - min_space + 1))
|
||||
if e - span_start - length - min_space > keep_length:
|
||||
new_parts.append((span_start + length + min_space, e))
|
||||
return new_parts
|
||||
|
||||
parts = [(0, sz)]
|
||||
min_length = min(lengths)
|
||||
for length in sorted(lengths, reverse=True):
|
||||
lens = np.fromiter(
|
||||
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||
np.int64,  # np.int was removed in recent NumPy releases
|
||||
)
|
||||
l_sum = np.sum(lens)
|
||||
if l_sum == 0:
|
||||
break
|
||||
probs = lens / np.sum(lens)
|
||||
c = rng.choice(len(parts), p=probs)
|
||||
s, e = parts.pop(c)
|
||||
parts.extend(arrange(s, e, length, min_length))
|
||||
mask_idc = np.asarray(mask_idc)
|
||||
else:
|
||||
if idc_select_ver == 1:
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
||||
elif idc_select_ver == 2:
|
||||
mask_idc = rng.choice(sz, num_mask, replace=False)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
mask_idc = np.asarray(
|
||||
[
|
||||
mask_idc[j] + offset
|
||||
for j in range(len(mask_idc))
|
||||
for offset in range(lengths[j])
|
||||
]
|
||||
)
|
||||
|
||||
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
||||
if len(mask_idc) >= sz:
|
||||
raise ValueError(
|
||||
(
|
||||
f"the entire sequence is masked. "
|
||||
f"sz={sz}; mask_idc={mask_idc}; "
|
||||
f"index={indices[i] if indices is not None else None}"
|
||||
)
|
||||
)
|
||||
mask_idcs.append(mask_idc)
|
||||
|
||||
target_len = None
|
||||
if require_same_masks:
|
||||
if add_masks:
|
||||
target_len = max([len(m) for m in mask_idcs])
|
||||
else:
|
||||
target_len = min([len(m) for m in mask_idcs])
|
||||
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if target_len is not None and len(mask_idc) > target_len:
|
||||
mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
||||
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
if target_len is not None and len(mask_idc) < target_len:
|
||||
unmasked = np.flatnonzero(~mask[i])
|
||||
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
|
||||
mask[i, to_mask] = True
|
||||
|
||||
if mask_dropout > 0:
|
||||
masked = np.flatnonzero(mask[i])
|
||||
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
|
||||
to_drop = rng.choice(masked, num_holes, replace=False)
|
||||
mask[i, to_drop] = False
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def add_hubert_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--label-rate",
|
||||
type=float,
|
||||
default=50,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=float,
|
||||
default=16000,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--extractor-mode",
|
||||
type=str,
|
||||
default="default",
|
||||
help="""mode for feature extractor, should be in EXTRACTOR_MODE_CHOICES. default has a single group
|
||||
norm with d groups in the first conv block, whereas layer_norm
|
||||
has layer norms in every block (meant to use with normalize=True)""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-layers",
|
||||
type=int,
|
||||
default=12,
|
||||
help="num encoder layers in the transformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-embed-dim",
|
||||
type=int,
|
||||
default=768,
|
||||
help="encoder embedding dimension",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-ffn-embed-dim",
|
||||
type=int,
|
||||
default=3072,
|
||||
help="encoder embedding dimension for FFN",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-attention-heads",
|
||||
type=int,
|
||||
default=12,
|
||||
help="num encoder attention heads",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--activation-fn",
|
||||
type=str,
|
||||
choices=[
|
||||
"relu",
|
||||
"gelu",
|
||||
"gelu_fast",
|
||||
"gelu_accurate",
|
||||
"tanh",
|
||||
"linear",
|
||||
],
|
||||
default="gelu",
|
||||
help="activation function to use",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--layer-type",
|
||||
type=str,
|
||||
choices=["transformer", "conformer", "trf_adp"],
|
||||
default="transformer",
|
||||
help="layer type in encoder",
|
||||
)
|
||||
|
||||
# dropouts
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="dropout probability for the transformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-dropout",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="dropout probability for attention weights",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--activation-dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="dropout probability after activation in FFN",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-layerdrop",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="probability of dropping a transformer layer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dropout-input",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="dropout to apply to the input (after feat extr)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dropout-features",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="dropout to apply to the features (after feat extr)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--final-dim",
|
||||
type=int,
|
||||
default=0,
|
||||
help="project final representations and targets to this many dimensions. set to encoder_embed_dim if <= 0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--untie-final-proj",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use separate projection for each target",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--layer-norm-first",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="apply layernorm first in the transformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-feature-layers",
|
||||
type=str,
|
||||
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
|
||||
help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-bias",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="include bias in conv encoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--logit-temp",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="temperature to divide logits by",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-glu",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="adds projection + glu to targets",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature-grad-mult",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="multiply feature extractor var grads by this",
|
||||
)
|
||||
|
||||
# masking
|
||||
parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-prob",
|
||||
type=float,
|
||||
default=0.65,
|
||||
help="probability of replacing a token with mask",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-selection",
|
||||
type=str,
|
||||
choices=["static", "uniform", "normal", "poisson"],
|
||||
default="static",
|
||||
help="how to choose mask length",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-other",
|
||||
type=float,
|
||||
default=0,
|
||||
help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-mask-overlap",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="whether to allow masks to overlap",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-min-space",
|
||||
type=int,
|
||||
default=1,
|
||||
help="min space between spans (if no overlap is enabled)",
|
||||
)
|
||||
|
||||
# channel masking
|
||||
parser.add_argument(
|
||||
"--mask-channel-length",
|
||||
type=int,
|
||||
default=10,
|
||||
help="length of the mask for features (channels)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-prob",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="probability of replacing a feature with 0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-selection",
|
||||
type=str,
|
||||
choices=["static", "uniform", "normal", "poisson"],
|
||||
default="static",
|
||||
help="how to choose mask length for channel masking",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-other",
|
||||
type=float,
|
||||
default=0,
|
||||
help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-mask-channel-overlap",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="whether to allow channel masks to overlap",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-channel-min-space",
|
||||
type=int,
|
||||
default=1,
|
||||
help="min space between spans (if no overlap is enabled)",
|
||||
)
|
||||
|
||||
# positional embeddings
|
||||
parser.add_argument(
|
||||
"--conv-pos",
|
||||
type=int,
|
||||
default=128,
|
||||
help="number of filters for convolutional positional embeddings",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-pos-groups",
|
||||
type=int,
|
||||
default=16,
|
||||
help="number of groups for convolutional positional embedding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--conv-pos-batch-norm",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--latent-temp",
|
||||
type=float,
|
||||
nargs="*",
|
||||
default=[2, 0.5, 0.999995],
|
||||
help="legacy (to be removed)",
|
||||
)
|
||||
|
||||
# loss computation
|
||||
parser.add_argument(
|
||||
"--skip-masked",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="skip computing losses over masked frames",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip-nomask",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="skip computing losses over unmasked frames",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint-activations",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="recompute activations and save memory for extra compute",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pred-masked-weight",
|
||||
type=float,
|
||||
default=1,
|
||||
help="weight for masked part in ssl loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pred-nomask-weight",
|
||||
type=float,
|
||||
default=0,
|
||||
help="weight for masked part in ssl loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--loss-weights",
|
||||
type=float,
|
||||
nargs="*",
|
||||
default=[10],
|
||||
help="weight for masked part in ssl loss",
|
||||
)
|
||||
|
||||
# FP16 optimization
|
||||
parser.add_argument(
|
||||
"--required-seq-len-multiple",
|
||||
type=int,
|
||||
default=2,
|
||||
help="pad the input to encoder such that the sequence length is divisible by multiple",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pos-enc-type",
|
||||
type=str,
|
||||
default="abs",
|
||||
help="Positional encoding type to use in conformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[504],
|
||||
help="""num class, a little larger than the number of cluster,
|
||||
the largest is for padding,
|
||||
and the value should be the multiple of 4, for faster computation""",
|
||||
)
|
||||
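# Illustrative helper (an assumption, not part of this recipe) showing how a
# --num-classes value such as the default [504] can be derived from the help
# text above: take the number of k-means clusters, add one id reserved for
# padding, then round up to a multiple of 4 for faster computation.
def _round_up_num_classes(num_clusters: int) -> int:
    n = num_clusters + 1          # one extra class reserved for padding
    return ((n + 3) // 4) * 4     # round up to the next multiple of 4

assert _round_up_num_classes(500) == 504  # 500 clusters -> the default of 504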
|
||||
|
||||
class HubertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
|
||||
self.embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=cfg.extractor_mode,
|
||||
conv_bias=cfg.conv_bias,
|
||||
)
|
||||
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
||||
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
|
||||
|
||||
self.post_extract_proj = (
|
||||
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
||||
if self.embed != cfg.encoder_embed_dim
|
||||
else None
|
||||
)
|
||||
|
||||
self.mask_prob = cfg.mask_prob
|
||||
self.mask_selection = cfg.mask_selection
|
||||
self.mask_other = cfg.mask_other
|
||||
self.mask_length = cfg.mask_length
|
||||
self.no_mask_overlap = cfg.no_mask_overlap
|
||||
self.mask_min_space = cfg.mask_min_space
|
||||
|
||||
self.mask_channel_prob = cfg.mask_channel_prob
|
||||
self.mask_channel_selection = cfg.mask_channel_selection
|
||||
self.mask_channel_other = cfg.mask_channel_other
|
||||
self.mask_channel_length = cfg.mask_channel_length
|
||||
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||
|
||||
self.feature_grad_mult = cfg.feature_grad_mult
|
||||
self.logit_temp = cfg.logit_temp
|
||||
self.skip_masked = cfg.skip_masked
|
||||
self.skip_nomask = cfg.skip_nomask
|
||||
|
||||
self.mask_emb = nn.Parameter(
|
||||
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
||||
)
|
||||
|
||||
self.encoder = TransformerEncoder(cfg)
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
self.untie_final_proj = cfg.untie_final_proj
|
||||
self.final_proj = nn.Linear(cfg.encoder_embed_dim, sum(cfg.num_classes))
|
||||
|
||||
# modules below are not needed during fine-tuning
|
||||
self.num_classes = cfg.num_classes
|
||||
self.pred_masked_weight = cfg.pred_masked_weight
|
||||
self.pred_nomask_weight = cfg.pred_nomask_weight
|
||||
self.loss_weights = cfg.loss_weights
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
||||
|
||||
super().upgrade_state_dict_named(state_dict, name)
|
||||
return state_dict
|
||||
|
||||
def apply_mask(self, x, padding_mask, target_list):
|
||||
B, T, C = x.shape
|
||||
if self.mask_prob > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=2,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space,
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||
x[mask_indices] = self.mask_emb.to(x.dtype)
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = (
|
||||
torch.from_numpy(mask_channel_indices)
|
||||
.to(x.device)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, T, -1)
|
||||
)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
|
||||
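# Toy sketch (an assumption, not the fairseq compute_mask_indices
# implementation) of the kind of boolean (B, T) mask consumed by apply_mask
# above: roughly mask_prob of the frames end up covered by spans of
# mask_length consecutive frames.
import numpy as np

def _toy_span_mask(B, T, mask_prob=0.65, mask_length=10, seed=0):
    rng = np.random.default_rng(seed)
    mask = np.zeros((B, T), dtype=bool)
    num_spans = int(mask_prob * T / mask_length)
    for b in range(B):
        starts = rng.choice(max(T - mask_length, 1), size=num_spans, replace=False)
        for s in starts:
            mask[b, s : s + mask_length] = True  # mark one span of frames
    return mask

_toy = _toy_span_mask(2, 200)
assert _toy.shape == (2, 200) and _toy.any()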
|
||||
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(source)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(source)
|
||||
return features
|
||||
|
||||
def forward_targets(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
target_list: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Trim features to ensure labels exist and then get aligned labels
|
||||
feat_tsz = features.size(2)
|
||||
targ_tsz = min([t.size(1) for t in target_list])
|
||||
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
||||
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
||||
features = features[..., :feat_tsz]
|
||||
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
|
||||
target_list = [t[:, target_inds.long()] for t in target_list]
|
||||
return features, target_list
|
||||
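# Worked example of the feat2tar_ratio computed in __init__ and used by
# forward_targets above, assuming the default settings: the conv extractor
# [(512,10,5)] + [(512,3,2)]*4 + [(512,2,2)]*2 has a total stride of
# 5 * 2**6 = 320 samples per frame, so 16 kHz audio yields 50 frames per
# second, and label_rate = 50 gives exactly one k-means label per frame.
import numpy as np

_strides = [5, 2, 2, 2, 2, 2, 2]
_ds_rate = int(np.prod(_strides))           # 320 samples per feature frame
_feat2tar_ratio = 50 * _ds_rate / 16000     # label_rate * ds_rate / sample_rate
assert _ds_rate == 320 and _feat2tar_ratio == 1.0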
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
extra = padding_mask.size(1) % features.size(1)
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
||||
padding_mask = padding_mask.all(-1)
|
||||
return padding_mask
|
||||
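# Minimal sketch of the reshaping trick in forward_padding_mask above: the
# sample-level padding mask is trimmed so that its length divides the number
# of feature frames, viewed as (B, T_feat, samples_per_frame), and a frame
# counts as padding only if all of its samples are padding.  The sizes below
# are made up for illustration.
import torch

_pm = torch.zeros(1, 1000, dtype=torch.bool)
_pm[:, 600:] = True                     # the last 400 samples are padding
_t_feat = 5                             # suppose the extractor produced 5 frames
_extra = _pm.size(1) % _t_feat
if _extra > 0:
    _pm = _pm[:, :-_extra]
_frame_mask = _pm.view(1, _t_feat, -1).all(-1)
assert _frame_mask.tolist() == [[False, False, False, True, True]]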
|
||||
def forward(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
target_list: Optional[List[torch.Tensor]] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = True,
|
||||
features_only: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
):
|
||||
"""output layer is 1-based"""
|
||||
features = self.forward_features(source)
|
||||
if target_list is not None:
|
||||
features, target_list = self.forward_targets(features, target_list)
|
||||
|
||||
features_pen = features.float().pow(2).mean()
|
||||
|
||||
features = features.transpose(1, 2)
|
||||
features = self.layer_norm(features)
|
||||
unmasked_features = features.clone()
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
features = self.dropout_input(features)
|
||||
unmasked_features = self.dropout_features(unmasked_features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
||||
else:
|
||||
x = features
|
||||
mask_indices = None
|
||||
|
||||
# feature: (B, T, D), float
|
||||
# target: (B, T), long
|
||||
# x: (B, T, D), float
|
||||
# padding_mask: (B, T), bool
|
||||
# mask_indices: (B, T), bool
|
||||
x, _ = self.encoder(
|
||||
x,
|
||||
padding_mask=padding_mask,
|
||||
layer=None if output_layer is None else output_layer - 1,
|
||||
)
|
||||
|
||||
if features_only:
|
||||
return {"x": x, "padding_mask": padding_mask, "features": features}
|
||||
|
||||
if not self.skip_masked:
|
||||
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
||||
proj_x_m = self.final_proj(x[masked_indices])
|
||||
proj_x_m /= self.logit_temp
|
||||
logit_m_list = [proj_x_m for _ in range(len(target_list))]
|
||||
else:
|
||||
logit_m_list = [None for _ in target_list]
|
||||
|
||||
if not self.skip_nomask:
|
||||
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
||||
proj_x_u = self.final_proj(x[nomask_indices])
|
||||
proj_x_u /= self.logit_temp
|
||||
logit_u_list = [proj_x_u for _ in range(len(target_list))]
|
||||
else:
|
||||
logit_u_list = [None for _ in target_list]
|
||||
|
||||
# result = {
|
||||
# "logit_m_list": logit_m_list,
|
||||
# "logit_u_list": logit_u_list,
|
||||
# "padding_mask": padding_mask,
|
||||
# "features_pen": features_pen,
|
||||
# }
|
||||
targ_m_list = target_list[0][masked_indices]
|
||||
targ_m_list = targ_m_list.long()
|
||||
targ_m_list = [targ_m_list for _ in range(len(target_list))]
|
||||
|
||||
targ_u_list = target_list[0][nomask_indices]
|
||||
targ_u_list = targ_u_list.long()
|
||||
targ_u_list = [targ_u_list for _ in range(len(target_list))]
|
||||
return self.compute_loss(
|
||||
logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
|
||||
)
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = False,
|
||||
ret_conv: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
res = self.forward(
|
||||
source,
|
||||
padding_mask=padding_mask,
|
||||
mask=mask,
|
||||
features_only=True,
|
||||
output_layer=output_layer,
|
||||
)
|
||||
feature = res["features"] if ret_conv else res["x"]
|
||||
return feature, res["padding_mask"]
|
||||
|
||||
def get_logits(self, net_output, is_masked=True):
|
||||
if is_masked:
|
||||
logits_list = net_output["logit_m_list"]
|
||||
else:
|
||||
logits_list = net_output["logit_u_list"]
|
||||
logits_list = [x.float() for x in logits_list if x is not None]
|
||||
return logits_list
|
||||
|
||||
def get_targets(self, net_output, is_masked=True):
|
||||
logits_list = self.get_logits(net_output, is_masked)
|
||||
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
|
||||
return targets_list
|
||||
|
||||
def get_extra_losses(self, net_output):
|
||||
extra_losses = []
|
||||
names = []
|
||||
|
||||
if "features_pen" in net_output:
|
||||
extra_losses.append(net_output["features_pen"])
|
||||
names.append("features_pen")
|
||||
|
||||
return extra_losses, names
|
||||
|
||||
def remove_pretraining_modules(self):
|
||||
self.final_proj = None
|
||||
|
||||
def compute_loss(
|
||||
self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
|
||||
):
|
||||
loss = 0.0
|
||||
sample_size = 0
|
||||
logging_output = {}
|
||||
reduce = True
|
||||
reduction = "sum" if reduce else "none"
|
||||
|
||||
loss_m_list = []
|
||||
logp_m_list = [x.float() for x in logit_m_list if x is not None]
|
||||
logp_m_list = torch.cat(logp_m_list)
|
||||
targ_m_list = torch.cat(targ_m_list)
|
||||
|
||||
loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
|
||||
loss_m_list.append(loss_m)
|
||||
logging_output[f"loss_m_0"] = loss_m.detach().item()
|
||||
|
||||
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
|
||||
if self.pred_masked_weight > 0:
|
||||
loss += self.pred_masked_weight * sum(loss_m_list)
|
||||
sample_size += len(targ_m_list)
|
||||
|
||||
loss_u_list = []
|
||||
logp_u_list = [x.float() for x in logit_u_list if x is not None]
|
||||
logp_u_list = torch.cat(logp_u_list)
|
||||
targ_u_list = torch.cat(targ_u_list)
|
||||
|
||||
loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
|
||||
loss_u_list.append(loss_u)
|
||||
logging_output[f"loss_u_0"] = loss_u.detach().item()
|
||||
|
||||
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
|
||||
if self.pred_nomask_weight > 0:
|
||||
loss += self.pred_nomask_weight * sum(loss_u_list)
|
||||
sample_size += len(targ_u_list)
|
||||
|
||||
if self.loss_weights is not None:
|
||||
extra_losses = []
|
||||
names = []
|
||||
extra_losses.append(features_pen)
|
||||
names.append("features_pen")
|
||||
if torch.is_tensor(extra_losses):
|
||||
extra_losses = [extra_losses]
|
||||
names = [names]
|
||||
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
||||
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
||||
assert len(extra_losses) == len(
|
||||
self.loss_weights
|
||||
), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
||||
for p, n, coef in zip(extra_losses, names, self.loss_weights):
|
||||
if coef != 0 and p is not None:
|
||||
p = coef * p.float() * sample_size
|
||||
loss += p
|
||||
logging_output[f"loss_{n}"] = p.item()
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.item() if reduce else loss,
|
||||
**logging_output,
|
||||
}
|
||||
|
||||
# for lk in self.log_keys:
|
||||
# if lk in net_output:
|
||||
# logging_output[lk] = float((net_output[lk]))
|
||||
|
||||
def compute_correct(logits, target):
|
||||
if logits.numel() == 0:
|
||||
return 0, 0
|
||||
else:
|
||||
assert logits.dim() > 1, logits.shape
|
||||
max = logits.argmax(-1) == target
|
||||
min = logits.argmin(-1) == target
|
||||
both = max & min
|
||||
corr = max.long().sum().item() - both.long().sum().item()
|
||||
count = max.numel()
|
||||
return corr, count
|
||||
|
||||
with torch.no_grad():
|
||||
corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
|
||||
logging_output[f"correct_m_0"] = corr_m
|
||||
logging_output[f"count_m_0"] = count_m
|
||||
|
||||
corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
|
||||
logging_output[f"correct_u_0"] = corr_u
|
||||
logging_output[f"count_u_0"] = count_u
|
||||
|
||||
return loss, sample_size, logging_output
|
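# Self-contained sketch of the masked-prediction term computed in
# compute_loss above: class logits over the masked frames are divided by
# logit_temp and scored with summed cross-entropy against the k-means
# labels.  All shapes and values here are made up for illustration.
import torch
import torch.nn.functional as F

_num_masked, _num_classes, _logit_temp = 7, 504, 0.1
_proj = torch.randn(_num_masked, _num_classes) / _logit_temp  # like proj_x_m
_labels = torch.randint(0, _num_classes, (_num_masked,))      # like targ_m_list
_loss_m = F.cross_entropy(_proj, _labels, reduction="sum")
assert _loss_m.ndim == 0 and _loss_m.item() > 0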
1
egs/librispeech/SSL/hubert/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/joiner.py
|
344
egs/librispeech/SSL/hubert/model.py
Normal file
@ -0,0 +1,344 @@
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Zengwei Yao,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 768,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
use_transducer: bool = True,
|
||||
use_ctc: bool = False,
|
||||
):
|
||||
"""A joint CTC & Transducer ASR model.
|
||||
|
||||
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
|
||||
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
|
||||
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
|
||||
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
inputs: `x` of (N, T, encoder_dim).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
It is used when use_transducer is True.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
It is used when use_transducer is True.
|
||||
use_transducer:
|
||||
Whether use transducer head. Default: True.
|
||||
use_ctc:
|
||||
Whether use CTC head. Default: False.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
use_transducer or use_ctc
|
||||
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||
|
||||
self.encoder = encoder
|
||||
|
||||
self.use_transducer = use_transducer
|
||||
if use_transducer:
|
||||
# Modules for Transducer head
|
||||
assert decoder is not None
|
||||
assert hasattr(decoder, "blank_id")
|
||||
assert joiner is not None
|
||||
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
else:
|
||||
assert decoder is None
|
||||
assert joiner is None
|
||||
|
||||
self.use_ctc = use_ctc
|
||||
if use_ctc:
|
||||
# Modules for CTC head
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
def forward_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute encoder outputs.
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (N, T).
|
||||
|
||||
Returns:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
"""
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.zeros_like(x, dtype=torch.bool)
|
||||
|
||||
encoder_out, padding_mask = self.encoder.extract_features(
|
||||
source=x,
|
||||
padding_mask=padding_mask,
|
||||
mask=self.encoder.training,
|
||||
)
|
||||
encoder_out_lens = torch.sum(~padding_mask, dim=1)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def forward_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
targets:
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets,
|
||||
input_lengths=encoder_out_lens,
|
||||
target_lengths=target_lengths,
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
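# Stand-alone sketch of the ctc_loss call above: log-probs are (T, N, C)
# after the permute, and the targets are one 1-D tensor formed by
# concatenating the un-padded label sequences, with target_lengths giving
# each utterance's length.  The sizes below are arbitrary.
import torch
import torch.nn.functional as F

_T, _N, _C = 50, 2, 30
_log_probs = torch.randn(_T, _N, _C).log_softmax(-1)
_targets = torch.randint(1, _C, (7 + 9,))          # two utterances, concatenated
_input_lengths = torch.tensor([50, 45])
_target_lengths = torch.tensor([7, 9])
_ctc = F.ctc_loss(_log_probs, _targets, _input_lengths, _target_lengths, reduction="sum")
assert torch.isfinite(_ctc)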
|
||||
def forward_transducer(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
y_lens: torch.Tensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute Transducer loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss; it means how many symbols (context)
we consider for each frame when computing the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
"""
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = encoder_out_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return simple_loss, pruned_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (N, T).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss; it means how many symbols (context)
we consider for each frame when computing the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer losses and CTC loss,
|
||||
in form of (simple_loss, pruned_loss, ctc_loss)
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 2, x.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == y.dim0, (x.shape, y.dim0)
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
|
||||
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
if self.use_transducer:
|
||||
# Compute transducer loss
|
||||
simple_loss, pruned_loss = self.forward_transducer(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
y=y.to(x.device),
|
||||
y_lens=y_lens,
|
||||
prune_range=prune_range,
|
||||
am_scale=am_scale,
|
||||
lm_scale=lm_scale,
|
||||
)
|
||||
else:
|
||||
simple_loss = torch.empty(0)
|
||||
pruned_loss = torch.empty(0)
|
||||
|
||||
if self.use_ctc:
|
||||
# Compute CTC loss
|
||||
targets = y.values
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
else:
|
||||
ctc_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
|
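# Hedged sketch (not the actual finetune script) of how a training loop
# might combine the values returned by AsrModel.forward; the scale names
# below are assumptions, chosen only to illustrate the usual weighting
# pattern of transducer and CTC losses.
import torch

def _combine_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    ctc_loss: torch.Tensor,
    simple_loss_scale: float = 0.5,
    ctc_loss_scale: float = 0.2,
) -> torch.Tensor:
    loss = torch.zeros((), dtype=torch.float32)
    if simple_loss.numel() > 0:          # transducer head enabled
        loss = loss + simple_loss_scale * simple_loss + pruned_loss
    if ctc_loss.numel() > 0:             # CTC head enabled
        loss = loss + ctc_loss_scale * ctc_loss
    return loss

assert _combine_losses(torch.tensor(2.0), torch.tensor(3.0), torch.empty(0)).item() == 4.0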
1
egs/librispeech/SSL/hubert/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/optim.py
|
1082
egs/librispeech/SSL/hubert/pretrain.py
Normal file
File diff suppressed because it is too large
Load Diff
1082
egs/librispeech/SSL/hubert/pretrain_ce.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/hubert/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/scaling.py
|
341
egs/librispeech/SSL/hubert/ssl_datamodule.py
Normal file
@ -0,0 +1,341 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from dataset import HubertDataset
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechDataModule:
|
||||
"""
|
||||
DataModule for SSL experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in SSL
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
|
||||
This class should be derived for specific corpora used in SSL tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="SSL data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/kmeans"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=float,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--do-normalize",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="whether to normalize the data",
|
||||
)
|
||||
group.add_argument(
|
||||
"--random-crop",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="always crop from the beginning if false",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
max_sample_size: Optional[int] = None,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = HubertDataset(
|
||||
max_sample_size=max_sample_size,
|
||||
sample_rate=sample_rate,
|
||||
label_rate=label_rate,
|
||||
random_crop=random_crop,
|
||||
pad_audio=pad_audio,
|
||||
num_classes=num_classes,
|
||||
do_normalize=do_normalize,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(
|
||||
self,
|
||||
cuts_valid: CutSet,
|
||||
max_sample_size: Optional[int] = None,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
validate = HubertDataset(
|
||||
max_sample_size=max_sample_size,
|
||||
sample_rate=sample_rate,
|
||||
label_rate=label_rate,
|
||||
random_crop=random_crop,
|
||||
pad_audio=pad_audio,
|
||||
num_classes=num_classes,
|
||||
do_normalize=do_normalize,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(
|
||||
self,
|
||||
cuts: CutSet,
|
||||
sample_rate: float = 16000,
|
||||
label_rate: float = 50,
|
||||
random_crop: bool = True,
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = HubertDataset(
|
||||
sample_rate=sample_rate,
|
||||
label_rate=label_rate,
|
||||
random_crop=random_crop,
|
||||
pad_audio=pad_audio,
|
||||
num_classes=num_classes,
|
||||
do_normalize=do_normalize,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
train_clean_100_cuts = self.train_clean_100_cuts()
|
||||
train_clean_360_cuts = self.train_clean_360_cuts()
|
||||
train_other_500_cuts = self.train_other_500_cuts()
|
||||
return CutSet.mux(
|
||||
train_clean_100_cuts,
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
weights=[
|
||||
28539, # len(train_clean_100_cuts)
|
||||
104014, # len(train_clean_360_cuts)
|
||||
148688, # len(train_other_500_cuts)
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
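# Usage sketch (an assumption, not taken from the recipe's training script)
# of how a script typically drives LibriSpeechDataModule: register the
# options on its own parser, build the module from the parsed args, pick a
# training cut set and ask for the dataloader.
def _build_train_dl(args: argparse.Namespace) -> DataLoader:
    librispeech = LibriSpeechDataModule(args)
    if args.full_libri:
        train_cuts = librispeech.train_all_shuf_cuts()
    else:
        train_cuts = librispeech.train_clean_100_cuts()
    return librispeech.train_dataloaders(
        train_cuts,
        do_normalize=args.do_normalize,
        random_crop=args.random_crop,
    )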
338
egs/librispeech/SSL/hubert/utils.py
Normal file
@ -0,0 +1,338 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def relu_squared(x: torch.Tensor):
|
||||
return F.relu(x).pow(2)
|
||||
|
||||
|
||||
def gelu_accurate(x):
|
||||
if not hasattr(gelu_accurate, "_a"):
|
||||
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||
return (
|
||||
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
|
||||
|
||||
def is_xla_tensor(tensor):
|
||||
return torch.is_tensor(tensor) and tensor.device.type == "xla"
|
||||
|
||||
|
||||
def index_put(tensor, indices, value):
|
||||
if is_xla_tensor(tensor):
|
||||
for _ in range(indices.dim(), tensor.dim()):
|
||||
indices = indices.unsqueeze(-1)
|
||||
if indices.size(-1) < tensor.size(-1):
|
||||
indices = indices.expand_as(tensor)
|
||||
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
|
||||
else:
|
||||
tensor[indices] = value
|
||||
return tensor
|
||||
|
||||
|
||||
def pad_to_multiple(x, multiple, dim=-1, value=0):
|
||||
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
|
||||
if x is None:
|
||||
return None, 0
|
||||
tsz = x.size(dim)
|
||||
m = tsz / multiple
|
||||
remainder = math.ceil(m) * multiple - tsz
|
||||
if m.is_integer():
|
||||
return x, 0
|
||||
pad_offset = (0,) * (-1 - dim) * 2
|
||||
|
||||
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
|
||||
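# Worked example for pad_to_multiple above: a (2, 10) tensor padded so that
# its last dimension becomes a multiple of 4 gains 2 trailing zeros, and the
# function also reports how many elements were added.
import torch

_x = torch.ones(2, 10)
_x_padded, _added = pad_to_multiple(_x, multiple=4, dim=-1, value=0)
assert _x_padded.shape == (2, 12) and _added == 2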
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str) -> Callable:
|
||||
"""Returns the activation function corresponding to `activation`"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "relu_squared":
|
||||
return relu_squared
|
||||
elif activation == "gelu":
|
||||
return gelu
|
||||
elif activation == "gelu_fast":
|
||||
return gelu_accurate
|
||||
elif activation == "gelu_accurate":
|
||||
return gelu_accurate
|
||||
elif activation == "tanh":
|
||||
return torch.tanh
|
||||
elif activation == "linear":
|
||||
return lambda x: x
|
||||
elif activation == "swish":
|
||||
return torch.nn.SiLU
|
||||
else:
|
||||
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||
|
||||
|
||||
class SamePad(nn.Module):
|
||||
def __init__(self, kernel_size, causal=False):
|
||||
super().__init__()
|
||||
if causal:
|
||||
self.remove = kernel_size - 1
|
||||
else:
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class SamePad2d(nn.Module):
|
||||
def __init__(self, kernel_size):
|
||||
super().__init__()
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 4
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class TransposeLast(nn.Module):
|
||||
def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
|
||||
super().__init__()
|
||||
self.deconstruct_idx = deconstruct_idx
|
||||
self.tranpose_dim = tranpose_dim
|
||||
|
||||
def forward(self, x):
|
||||
if self.deconstruct_idx is not None:
|
||||
x = x[self.deconstruct_idx]
|
||||
return x.transpose(self.tranpose_dim, -1)
|
||||
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm as _FusedLayerNorm
|
||||
|
||||
has_fused_layernorm = True
|
||||
|
||||
class FusedLayerNorm(_FusedLayerNorm):
|
||||
@torch.jit.unused
|
||||
def forward(self, x):
|
||||
if not x.is_cuda:
|
||||
return super().forward(x)
|
||||
else:
|
||||
with torch.cuda.device(x.device):
|
||||
return super().forward(x)
|
||||
|
||||
except ImportError:
|
||||
has_fused_layernorm = False
|
||||
|
||||
|
||||
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
export = True
|
||||
if not export and torch.cuda.is_available() and has_fused_layernorm:
|
||||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
|
||||
class Fp32LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.layer_norm(
|
||||
input.float(),
|
||||
self.normalized_shape,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
class Fp32GroupNorm(nn.GroupNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.group_norm(
|
||||
input.float(),
|
||||
self.num_groups,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
def softmax(x, dim: int, onnx_trace: bool = False):
|
||||
if onnx_trace:
|
||||
return F.softmax(x.float(), dim=dim)
|
||||
else:
|
||||
return F.softmax(x, dim=dim, dtype=torch.float32)
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
|
||||
"""
|
||||
Wraps modules and applies quantization noise to the weights for
|
||||
subsequent quantization with Iterative Product Quantization as
|
||||
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||
|
||||
Args:
|
||||
- module: nn.Module
|
||||
- p: amount of Quantization Noise
|
||||
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||
|
||||
Remarks:
|
||||
- Module weights must have the right sizes wrt the block size
|
||||
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||
- For more detail on how to quantize by blocks with convolutional weights,
|
||||
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||
- We implement the simplest form of noise here as stated in the paper
|
||||
which consists in randomly dropping blocks
|
||||
"""
|
||||
|
||||
# if no quantization noise, don't register hook
|
||||
if p <= 0:
|
||||
return module
|
||||
|
||||
# supported modules
|
||||
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||
|
||||
# test whether module.weight has the right sizes wrt block_size
|
||||
is_conv = module.weight.ndim == 4
|
||||
|
||||
# 2D matrix
|
||||
if not is_conv:
|
||||
assert (
|
||||
module.weight.size(1) % block_size == 0
|
||||
), "Input features must be a multiple of block sizes"
|
||||
|
||||
# 4D matrix
|
||||
else:
|
||||
# 1x1 convolutions
|
||||
if module.kernel_size == (1, 1):
|
||||
assert (
|
||||
module.in_channels % block_size == 0
|
||||
), "Input channels must be a multiple of block sizes"
|
||||
# regular convolutions
|
||||
else:
|
||||
k = module.kernel_size[0] * module.kernel_size[1]
|
||||
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||
|
||||
def _forward_pre_hook(mod, input):
|
||||
# no noise for evaluation
|
||||
if mod.training:
|
||||
if not is_conv:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_features = weight.size(1)
|
||||
out_features = weight.size(0)
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
mask = torch.zeros(
|
||||
in_features // block_size * out_features,
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||
|
||||
else:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_channels = mod.in_channels
|
||||
out_channels = mod.out_channels
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
if mod.kernel_size == (1, 1):
|
||||
mask = torch.zeros(
|
||||
int(in_channels // block_size * out_channels),
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||
else:
|
||||
mask = torch.zeros(
|
||||
weight.size(0), weight.size(1), device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = (
|
||||
mask.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||
)
|
||||
|
||||
# scale weights and apply mask
|
||||
mask = mask.to(
|
||||
torch.bool
|
||||
) # x.bool() is not currently supported in TorchScript
|
||||
s = 1 / (1 - p)
|
||||
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||
|
||||
module.register_forward_pre_hook(_forward_pre_hook)
|
||||
return module
|
||||
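# Usage sketch for quant_noise above: wrapping a Linear layer so that,
# during training, random blocks of its weight matrix are dropped (and the
# remainder rescaled) to prepare the model for iterative product
# quantization.  The layer sizes are arbitrary; 16 % 8 == 0 satisfies the
# block-size assertion.
import torch
import torch.nn as nn

_linear = quant_noise(nn.Linear(16, 8), p=0.1, block_size=8)
_linear.train()
_out = _linear(torch.randn(4, 16))
assert _out.shape == (4, 8)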
|
||||
|
||||
class FairseqDropout(nn.Module):
|
||||
def __init__(self, p, module_name=None):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
self.module_name = module_name
|
||||
self.apply_during_inference = False
|
||||
|
||||
def forward(self, x, inplace: bool = False):
|
||||
if self.p > 0 and (self.training or self.apply_during_inference):
|
||||
return F.dropout(x, p=self.p, training=True, inplace=inplace)
|
||||
else:
|
||||
return x
|
||||
|
||||
def make_generation_fast_(
|
||||
self,
|
||||
name: str,
|
||||
retain_dropout: bool = False,
|
||||
retain_dropout_modules: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
):
|
||||
if retain_dropout:
|
||||
if retain_dropout_modules is not None and self.module_name is None:
|
||||
pass
|
||||
elif (
|
||||
retain_dropout_modules is None # if None, apply to all modules
|
||||
or self.module_name in retain_dropout_modules
|
||||
):
|
||||
self.apply_during_inference = True
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
|
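# Quick check of GradMultiply above: the forward pass is an identity, while
# gradients flowing back are scaled by `scale`; this is how
# feature_grad_mult damps the conv feature extractor's gradients during
# pre-training.
import torch

_x = torch.ones(3, requires_grad=True)
_y = GradMultiply.apply(_x, 0.1)
_y.sum().backward()
assert torch.allclose(_x.grad, torch.full((3,), 0.1))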
593
egs/librispeech/SSL/hubert/wav2vec2_module.py
Normal file
@ -0,0 +1,593 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from attention_module import MultiheadAttention, init_bert_params
|
||||
from utils import (
|
||||
Fp32GroupNorm,
|
||||
Fp32LayerNorm,
|
||||
LayerNorm,
|
||||
SamePad,
|
||||
TransposeLast,
|
||||
get_activation_fn,
|
||||
index_put,
|
||||
pad_to_multiple,
|
||||
)
|
||||
|
||||
|
||||
class ConvFeatureExtractionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conv_layers: List[Tuple[int, int, int]],
|
||||
dropout: float = 0.0,
|
||||
mode: str = "default",
|
||||
conv_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode in {"default", "layer_norm"}
|
||||
|
||||
def block(
|
||||
n_in,
|
||||
n_out,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=False,
|
||||
is_group_norm=False,
|
||||
conv_bias=False,
|
||||
):
|
||||
def make_conv():
|
||||
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
||||
nn.init.kaiming_normal_(conv.weight)
|
||||
return conv
|
||||
|
||||
assert (
|
||||
is_layer_norm and is_group_norm
|
||||
) == False, "layer norm and group norm are exclusive"
|
||||
|
||||
if is_layer_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Sequential(
|
||||
TransposeLast(),
|
||||
Fp32LayerNorm(dim, elementwise_affine=True),
|
||||
TransposeLast(),
|
||||
),
|
||||
nn.GELU(),
|
||||
)
|
||||
elif is_group_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
Fp32GroupNorm(dim, dim, affine=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
||||
|
||||
in_d = 1
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
||||
(dim, k, stride) = cl
|
||||
|
||||
self.conv_layers.append(
|
||||
block(
|
||||
in_d,
|
||||
dim,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=mode == "layer_norm",
|
||||
is_group_norm=mode == "default" and i == 0,
|
||||
conv_bias=conv_bias,
|
||||
)
|
||||
)
|
||||
in_d = dim
|
||||
|
||||
def forward(self, x):
|
||||
# BxT -> BxCxT
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_conv_pos(e, k, g, is_batch_norm=False):
|
||||
pos_conv = nn.Conv1d(
|
||||
e,
|
||||
e,
|
||||
kernel_size=k,
|
||||
padding=k // 2,
|
||||
groups=g,
|
||||
)
|
||||
dropout = 0
|
||||
std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
|
||||
nn.init.normal_(pos_conv.weight, mean=0, std=std)
|
||||
nn.init.constant_(pos_conv.bias, 0)
|
||||
|
||||
if not is_batch_norm:
|
||||
pos_conv = nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2)
|
||||
pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
|
||||
else:
|
||||
batch_norm = nn.BatchNorm1d(e)
|
||||
pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())
|
||||
|
||||
return pos_conv
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def build_encoder_layer(self, args, **kwargs):
|
||||
if args.layer_type == "transformer":
|
||||
layer = TransformerSentenceEncoderLayer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_fn=args.activation_fn,
|
||||
layer_norm_first=args.layer_norm_first,
|
||||
)
|
||||
elif args.layer_type == "trf_adp":
|
||||
use_adp = False
|
||||
if args.adp_trf_idx == "all":
|
||||
use_adp = True
|
||||
else:
|
||||
adp_trf_idx = list(
|
||||
range(*[int(g) for g in args.adp_trf_idx.split(":")])
|
||||
)
|
||||
if kwargs.get("layer_idx", None) in adp_trf_idx:
|
||||
use_adp = True
|
||||
if use_adp:
|
||||
layer = TransformerSentenceEncoderWithAdapterLayer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_fn=args.activation_fn,
|
||||
layer_norm_first=args.layer_norm_first,
|
||||
adapter_num=args.adp_num,
|
||||
adapter_dim=args.adp_dim,
|
||||
adapter_act_fn=args.adp_act_fn,
|
||||
)
|
||||
else:
|
||||
layer = TransformerSentenceEncoderLayer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_fn=args.activation_fn,
|
||||
layer_norm_first=args.layer_norm_first,
|
||||
)
|
||||
|
||||
# layer = fsdp_wrap(layer)
|
||||
# if args.checkpoint_activations:
|
||||
# layer = checkpoint_wrapper(layer)
|
||||
return layer
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
|
||||
self.args = args  # kept so that max_positions() below can read it
self.dropout = args.dropout
|
||||
self.embedding_dim = args.encoder_embed_dim
|
||||
self.required_seq_len_multiple = args.required_seq_len_multiple
|
||||
|
||||
pos_conv_depth = getattr(args, "pos_conv_depth", 1)
|
||||
if pos_conv_depth > 1:
|
||||
num_layers = args.pos_conv_depth
|
||||
k = max(3, args.conv_pos // num_layers)
|
||||
|
||||
def make_conv_block(e, k, g, l):
|
||||
return nn.Sequential(
|
||||
*[
|
||||
nn.Sequential(
|
||||
nn.Conv1d(
|
||||
e,
|
||||
e,
|
||||
kernel_size=k,
|
||||
padding=k // 2,
|
||||
groups=g,
|
||||
),
|
||||
SamePad(k),
|
||||
TransposeLast(),
|
||||
LayerNorm(e, elementwise_affine=False),
|
||||
TransposeLast(),
|
||||
nn.GELU(),
|
||||
)
|
||||
for _ in range(l)
|
||||
]
|
||||
)
|
||||
|
||||
self.pos_conv = make_conv_block(
|
||||
self.embedding_dim, k, args.conv_pos_groups, num_layers
|
||||
)
|
||||
|
||||
else:
|
||||
self.pos_conv = make_conv_pos(
|
||||
self.embedding_dim,
|
||||
args.conv_pos,
|
||||
args.conv_pos_groups,
|
||||
is_batch_norm=args.conv_pos_batch_norm
|
||||
if hasattr(args, "conv_pos_batch_norm")
|
||||
else False,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
self.build_encoder_layer(args, layer_idx=ii)
|
||||
for ii in range(args.encoder_layers)
|
||||
]
|
||||
)
|
||||
self.layer_norm_first = args.layer_norm_first
|
||||
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||
self.layerdrop = args.encoder_layerdrop
|
||||
|
||||
self.apply(init_bert_params)
|
||||
|
||||
def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
|
||||
x, layer_results = self.extract_features(
|
||||
x, padding_mask, layer, corpus_key=corpus_key
|
||||
)
|
||||
|
||||
if self.layer_norm_first and layer is None:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
x,
|
||||
padding_mask=None,
|
||||
tgt_layer=None,
|
||||
min_layer=0,
|
||||
corpus_key=None,
|
||||
):
|
||||
if padding_mask is not None:
|
||||
x = index_put(x, padding_mask, 0)
|
||||
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x = x + x_conv
|
||||
|
||||
if not self.layer_norm_first:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# pad to the sequence length dimension
|
||||
x, pad_length = pad_to_multiple(
|
||||
x, self.required_seq_len_multiple, dim=-2, value=0
|
||||
)
|
||||
if pad_length > 0 and padding_mask is None:
|
||||
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
|
||||
padding_mask[:, -pad_length:] = True
|
||||
else:
|
||||
padding_mask, _ = pad_to_multiple(
|
||||
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
layer_results = []
|
||||
r = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
|
||||
if not self.training or (dropout_probability > self.layerdrop):
|
||||
layer_check = layer
|
||||
# if isinstance(layer, FullyShardedDataParallel):
|
||||
# layer_check = layer.unwrapped_module
|
||||
if (corpus_key is None) or (
|
||||
not isinstance(
|
||||
layer_check,
|
||||
(TransformerSentenceEncoderWithAdapterLayer,),
|
||||
)
|
||||
):
|
||||
x, (z, lr) = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
need_weights=False,
|
||||
)
|
||||
else:
|
||||
x, (z, lr) = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
need_weights=False,
|
||||
corpus_key=corpus_key,
|
||||
)
|
||||
if i >= min_layer:
|
||||
layer_results.append((x, z, lr))
|
||||
if i == tgt_layer:
|
||||
r = x
|
||||
break
|
||||
|
||||
if r is not None:
|
||||
x = r
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
# undo padding
|
||||
if pad_length > 0:
|
||||
x = x[:, :-pad_length]
|
||||
|
||||
def undo_pad(a, b, c):
|
||||
return (
|
||||
a[:-pad_length],
|
||||
b[:-pad_length] if b is not None else b,
|
||||
c[:-pad_length],
|
||||
)
|
||||
|
||||
layer_results = [undo_pad(*u) for u in layer_results]
|
||||
|
||||
return x, layer_results
|
||||
|
||||
def max_positions(self):
|
||||
"""Maximum output length supported by the encoder."""
|
||||
return self.args.max_positions
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
||||
return state_dict
|
||||
|
||||
|
||||
class TransformerSentenceEncoderLayer(nn.Module):
|
||||
"""
|
||||
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
||||
models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: float = 768,
|
||||
ffn_embedding_dim: float = 3072,
|
||||
num_attention_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
activation_fn: str = "relu",
|
||||
layer_norm_first: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Initialize parameters
|
||||
self.embedding_dim = embedding_dim
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
|
||||
# Initialize blocks
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.self_attn = MultiheadAttention(
|
||||
self.embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout=attention_dropout,
|
||||
self_attention=True,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.layer_norm_first = layer_norm_first
|
||||
|
||||
# layer norm associated with the self attention layer
|
||||
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||
|
||||
# layer norm associated with the position wise feed-forward NN
|
||||
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
self_attn_mask: torch.Tensor = None,
|
||||
self_attn_padding_mask: torch.Tensor = None,
|
||||
need_weights: bool = False,
|
||||
att_args=None,
|
||||
):
|
||||
"""
|
||||
LayerNorm is applied either before or after the self-attention/ffn
|
||||
modules, similar to the original Transformer implementation.
(A schematic comparison of the two orderings follows the class definition.)
|
||||
"""
|
||||
residual = x
|
||||
|
||||
if self.layer_norm_first:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
attn_mask=self_attn_mask,
|
||||
need_weights=False,
|
||||
)
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
|
||||
layer_result = x
|
||||
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
else:
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=False,
|
||||
)
|
||||
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
|
||||
layer_result = x
|
||||
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
return x, (attn, layer_result)
|
||||
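# Illustrative note (not from the original code): the two branches in forward() above
# differ only in where LayerNorm sits relative to the residual connection. Schematically,
# with `attn` / `ffn` standing for the sub-layers built in __init__:
#
#     layer_norm_first=True  (pre-norm):   x = x + attn(LN(x));  x = x + ffn(LN(x))
#     layer_norm_first=False (post-norm):  x = LN(x + attn(x));  x = LN(x + ffn(x))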
|
||||
|
||||
class AdapterFast(nn.Module):
|
||||
def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
|
||||
"""
|
||||
Implements adapter modules directly with 3D tensor weights as parameters
|
||||
and without using ModuleList, in order to speed up training throughput.
(A brief usage sketch follows the class definition.)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.adapter_num = adapter_num
|
||||
self.input_dim = input_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
|
||||
self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
|
||||
self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
|
||||
self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
||||
|
||||
self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
|
||||
self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
|
||||
self.act_fn = nn.Identity()
|
||||
if act_fn == "relu":
|
||||
self.act_fn = nn.ReLU()
|
||||
elif act_fn == "gelu":
|
||||
self.act_fn = nn.GELU()
|
||||
elif act_fn == "selu":
|
||||
self.act_fn = nn.SELU()
|
||||
else:
|
||||
raise ValueError(f"unsupported {act_fn}")
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
for ii in range(self.adapter_num):
|
||||
nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
nn.init.uniform_(self.b_a[ii], -bound, bound)
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
nn.init.uniform_(self.b_b[ii], -bound, bound)
|
||||
|
||||
nn.init.ones_(self.ln_W)
|
||||
nn.init.zeros_(self.ln_b)
|
||||
|
||||
def forward(self, x, adapter_id):
|
||||
ii = adapter_id
|
||||
h = x
|
||||
h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii])
|
||||
h = F.linear(h, self.W_a[ii], self.b_a[ii])
|
||||
h = self.act_fn(h)
|
||||
h = F.linear(h, self.W_b[ii], self.b_b[ii])
|
||||
outputs = h
|
||||
return outputs
|
||||
|
||||
def extra_repr(self):
|
||||
return "adapter={}, input_dim={}, hidden_dim={}".format(
|
||||
self.adapter_num, self.input_dim, self.hidden_dim
|
||||
)
|
||||
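# A minimal usage sketch (illustrative values, not part of the recipe): one adapter out
# of `adapter_num` is selected by an integer id at forward time, and the caller adds the
# bottleneck output back as a residual, as done in the layer class below.
#
#     adapters = AdapterFast(adapter_num=4, input_dim=768, hidden_dim=64, act_fn="relu")
#     h = torch.randn(50, 2, 768)        # (T, B, C) hidden states
#     y = h + adapters(h, adapter_id=1)  # residual connection around adapter #1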
|
||||
|
||||
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
|
||||
"""
|
||||
Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained
|
||||
models. An adapter module is added along with vanilla Transformer module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: float = 768,
|
||||
ffn_embedding_dim: float = 3072,
|
||||
num_attention_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
activation_fn: str = "relu",
|
||||
layer_norm_first: bool = False,
|
||||
adapter_num=201,
|
||||
adapter_dim=64,
|
||||
adapter_act_fn="relu",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=embedding_dim,
|
||||
ffn_embedding_dim=ffn_embedding_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
dropout=dropout,
|
||||
attention_dropout=attention_dropout,
|
||||
activation_dropout=activation_dropout,
|
||||
activation_fn=activation_fn,
|
||||
layer_norm_first=layer_norm_first,
|
||||
)
|
||||
|
||||
self.adapter_num = adapter_num
|
||||
self.adapter_dim = adapter_dim
|
||||
self.adapter_layer = AdapterFast(
|
||||
adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
self_attn_mask: torch.Tensor = None,
|
||||
self_attn_padding_mask: torch.Tensor = None,
|
||||
need_weights: bool = False,
|
||||
att_args=None,
|
||||
corpus_key=None,
|
||||
):
|
||||
x, (attn, layer_result) = super().forward(
|
||||
x=x,
|
||||
self_attn_mask=self_attn_mask,
|
||||
self_attn_padding_mask=self_attn_padding_mask,
|
||||
need_weights=need_weights,
|
||||
att_args=att_args,
|
||||
)
|
||||
assert corpus_key is not None
|
||||
assert len(set(corpus_key)) == 1, f"corpus_key items are not the same: {corpus_key}"
|
||||
y = self.adapter_layer(x, corpus_key[0])
|
||||
x = x + y
|
||||
return x, (attn, layer_result)
|
52
egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py
Normal file
@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import jsonlines
|
||||
from tqdm import tqdm
|
||||
|
||||
os.system(
|
||||
"cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_dev-clean* ."
|
||||
)
|
||||
os.system(
|
||||
"cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_train* ."
|
||||
)
|
||||
os.system("chmod -R 644 *.jsonl.gz")
|
||||
os.system("gunzip *.gz")
|
||||
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
)
|
||||
|
||||
kmeans_dir = "/userhome/user/yangguanrou/data/k500"
|
||||
idx_dir = "/userhome/user/yangguanrou/data/shu"
|
||||
|
||||
kmeans = []
|
||||
idxs = []
|
||||
for part in ["train", "valid"]:
|
||||
with open(kmeans_dir + "/" + part + ".km", "r") as f:
|
||||
kmeans += f.read().splitlines()
|
||||
|
||||
with open(idx_dir + "/" + part + ".tsv", "r") as f:
|
||||
lines = f.read().splitlines()
|
||||
idxs += [
|
||||
line.split("\t", -1)[0].split("/", -1)[-1].replace(".flac", "")
|
||||
for line in lines
|
||||
if ".flac" in line
|
||||
]
|
||||
|
||||
idx2kmeans = {}
|
||||
for idx, km in zip(idxs, kmeans):
|
||||
idx2kmeans[idx] = km
|
||||
|
||||
for part in dataset_parts:
|
||||
with jsonlines.open(f"librispeech_supervisions_{part}.jsonl") as reader:
|
||||
with jsonlines.open(
|
||||
f"librispeech_supervisions_{part}_new.jsonl", mode="w"
|
||||
) as writer:
|
||||
for obj in tqdm(reader):
|
||||
obj["custom"] = {"kmeans": idx2kmeans[obj["id"]]}
|
||||
writer.write(obj)
|
||||
|
||||
os.system('for file in *_new.jsonl; do mv "$file" "${file%_new.jsonl}.jsonl"; done')
|
18
egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Simple script to convert a fairseq checkpoint into a plain PyTorch parameter state dict
|
||||
from argparse import ArgumentParser
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--src")
|
||||
parser.add_argument("--tgt")
|
||||
|
||||
args = parser.parse_args()
|
||||
src = args.src
|
||||
tgt = args.tgt
|
||||
|
||||
old_checkpoint = torch.load(src, map_location="cpu")
|
||||
new_checkpoint = OrderedDict()
|
||||
new_checkpoint["model"] = old_checkpoint["model"]
|
||||
torch.save(new_checkpoint, tgt)
|
259
egs/librispeech/SSL/local/prepare_char.py
Normal file
@ -0,0 +1,259 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
|
||||
- lang_dir/text,
|
||||
- lang_dir/words.txt
|
||||
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
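# Typical invocation (illustrative; the actual lang dir name depends on the recipe):
#
#     ./local/prepare_char.py --lang-dir data/lang_char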
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||
|
||||
arcs = []
|
||||
|
||||
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||
assert token2id["<blk>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, pieces in lexicon:
|
||||
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last piece of this word
|
||||
i = len(pieces) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
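# Illustrative sketch (the token/word tables below are made up): every word leaves and
# re-enters the single loop state, and the word label is emitted on the first token's arc.
#
#     token2id = {"<blk>": 0, "a": 1, "b": 2, "<unk>": 3}
#     word2id = {"<eps>": 0, "ab": 1, "a": 2}
#     lexicon = [("ab", ["a", "b"]), ("a", ["a"])]
#     L = lexicon_to_fst_no_sil(lexicon, token2id=token2id, word2id=word2id)
#     # L is a k2.Fsa and can be saved with torch.save(L.as_dict(), "L.pt")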
|
||||
|
||||
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
"""Check if all the given tokens are in token symbol table.
|
||||
|
||||
Args:
|
||||
token_sym_table:
|
||||
Token symbol table that contains all the valid tokens.
|
||||
tokens:
|
||||
A list of tokens.
|
||||
Returns:
|
||||
Return True if there is any token not in the token_sym_table,
|
||||
otherwise False.
|
||||
"""
|
||||
for tok in tokens:
|
||||
if tok not in token_sym_table:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
|
||||
Args:
|
||||
token_sym_table:
|
||||
Token symbol table that maps tokens to token IDs.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
Returns:
|
||||
Return a dict whose keys are words and values are the corresponding
|
||||
tokens.
|
||||
"""
|
||||
lexicon = []
|
||||
for word in words:
|
||||
chars = list(word.strip(" \t"))
|
||||
if contain_oov(token_sym_table, chars):
|
||||
continue
|
||||
lexicon.append((word, chars))
|
||||
|
||||
# The OOV word is <UNK>
|
||||
lexicon.append(("<UNK>", ["<unk>"]))
|
||||
return lexicon
|
||||
|
||||
|
||||
def generate_tokens(text_file: str) -> Dict[str, int]:
|
||||
"""Generate tokens from the given text file.
|
||||
|
||||
Args:
|
||||
text_file:
|
||||
A file that contains text lines to generate tokens.
|
||||
Returns:
|
||||
Return a dict whose keys are tokens and values are token ids ranged
|
||||
from 0 to len(keys) - 1.
|
||||
"""
|
||||
tokens: Dict[str, int] = dict()
|
||||
tokens["<blk>"] = 0
|
||||
tokens["<sos/eos>"] = 1
|
||||
tokens["<unk>"] = 2
|
||||
whitespace = re.compile(r"([ \t\r\n]+)")
|
||||
with open(text_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = re.sub(whitespace, "", line)
|
||||
chars = list(line)
|
||||
for char in chars:
|
||||
if char not in tokens:
|
||||
tokens[char] = len(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain the files text and words.txt
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
text_file = lang_dir / "text"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
token_sym_table = generate_tokens(text_file)
|
||||
|
||||
lexicon = generate_lexicon(token_sym_table, words)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
next_token_id = max(token_sym_table.values()) + 1
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table[disambig] = next_token_id
|
||||
next_token_id += 1
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token_sym_table)
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
388
egs/librispeech/SSL/local/prepare_lang.py
Normal file
@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
|
||||
consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
|
||||
|
||||
2. Generate tokens.txt, the token table mapping a token to a unique integer.
|
||||
|
||||
3. Generate words.txt, the word table mapping a word to a unique integer.
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
"""
|
||||
import argparse
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import read_lexicon, write_lexicon
|
||||
|
||||
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
||||
"""Write a symbol to ID mapping to a file.
|
||||
|
||||
Note:
|
||||
No need to implement `read_mapping` as it can be done
|
||||
through :func:`k2.SymbolTable.from_file`.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename to save the mapping.
|
||||
sym2id:
|
||||
A dict mapping symbols to IDs.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
for sym, i in sym2id.items():
|
||||
f.write(f"{sym} {i}\n")
|
||||
|
||||
|
||||
def get_tokens(lexicon: Lexicon) -> List[str]:
|
||||
"""Get tokens from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique tokens.
|
||||
"""
|
||||
ans = set()
|
||||
for _, tokens in lexicon:
|
||||
ans.update(tokens)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def get_words(lexicon: Lexicon) -> List[str]:
|
||||
"""Get words from a lexicon.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is the return value of :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a list of unique words.
|
||||
"""
|
||||
ans = set()
|
||||
for word, _ in lexicon:
|
||||
ans.add(word)
|
||||
sorted_ans = sorted(list(ans))
|
||||
return sorted_ans
|
||||
|
||||
|
||||
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
|
||||
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
|
||||
at the ends of tokens to ensure that all pronunciations are different,
|
||||
and that none is a prefix of another.
|
||||
|
||||
See also add_lex_disambig.pl from kaldi.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
It is returned by :func:`read_lexicon`.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
|
||||
- The output lexicon with disambiguation symbols
|
||||
- The ID of the max disambiguation symbol that appears
|
||||
in the lexicon
|
||||
"""
|
||||
|
||||
# (1) Work out the count of each token-sequence in the
|
||||
# lexicon.
|
||||
count = defaultdict(int)
|
||||
for _, tokens in lexicon:
|
||||
count[" ".join(tokens)] += 1
|
||||
|
||||
# (2) For each left sub-sequence of each token-sequence, note down
|
||||
# that it exists (for identifying prefixes of longer strings).
|
||||
issubseq = defaultdict(int)
|
||||
for _, tokens in lexicon:
|
||||
tokens = tokens.copy()
|
||||
tokens.pop()
|
||||
while tokens:
|
||||
issubseq[" ".join(tokens)] = 1
|
||||
tokens.pop()
|
||||
|
||||
# (3) For each entry in the lexicon:
|
||||
# if the token sequence is unique and is not a
|
||||
# prefix of another word, no disambig symbol.
|
||||
# Else output #1, or #2, #3, ... if the same token-seq
|
||||
# has already been assigned a disambig symbol.
|
||||
ans = []
|
||||
|
||||
# We start with #1 since #0 has its own purpose
|
||||
first_allowed_disambig = 1
|
||||
max_disambig = first_allowed_disambig - 1
|
||||
last_used_disambig_symbol_of = defaultdict(int)
|
||||
|
||||
for word, tokens in lexicon:
|
||||
tokenseq = " ".join(tokens)
|
||||
assert tokenseq != ""
|
||||
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
|
||||
ans.append((word, tokens))
|
||||
continue
|
||||
|
||||
cur_disambig = last_used_disambig_symbol_of[tokenseq]
|
||||
if cur_disambig == 0:
|
||||
cur_disambig = first_allowed_disambig
|
||||
else:
|
||||
cur_disambig += 1
|
||||
|
||||
if cur_disambig > max_disambig:
|
||||
max_disambig = cur_disambig
|
||||
last_used_disambig_symbol_of[tokenseq] = cur_disambig
|
||||
tokenseq += f" #{cur_disambig}"
|
||||
ans.append((word, tokenseq.split()))
|
||||
return ans, max_disambig
|
||||
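# Worked example (illustrative, not part of the recipe): one token sequence is repeated
# and another is a prefix of a longer one, so both need disambiguation symbols.
#
#     lexicon = [("read", ["r", "iy", "d"]),
#                ("reed", ["r", "iy", "d"]),
#                ("a",    ["ey"]),
#                ("able", ["ey", "b", "l"])]
#     lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
#     # "read" -> "r iy d #1", "reed" -> "r iy d #2", "a" -> "ey #1",
#     # "able" is untouched, and max_disambig == 2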
|
||||
|
||||
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
|
||||
"""Generate ID maps, i.e., map a symbol to a unique ID.
|
||||
|
||||
Args:
|
||||
symbols:
|
||||
A list of unique symbols.
|
||||
Returns:
|
||||
A dict containing the mapping between symbols and IDs.
|
||||
"""
|
||||
return {sym: i for i, sym in enumerate(symbols)}
|
||||
|
||||
|
||||
def add_self_loops(
|
||||
arcs: List[List[Any]], disambig_token: int, disambig_word: int
|
||||
) -> List[List[Any]]:
|
||||
"""Adds self-loops to states of an FST to propagate disambiguation symbols
|
||||
through it. They are added on each state with non-epsilon output symbols
|
||||
on at least one arc out of the state.
|
||||
|
||||
See also fstaddselfloops.pl from Kaldi. One difference is that
|
||||
Kaldi uses OpenFst style FSTs and it has multiple final states.
|
||||
This function uses k2 style FSTs and it does not need to add self-loops
|
||||
to the final state.
|
||||
|
||||
The input label of a self-loop is `disambig_token`, while the output
|
||||
label is `disambig_word`.
|
||||
|
||||
Args:
|
||||
arcs:
|
||||
A list-of-list. The sublist contains
|
||||
`[src_state, dest_state, label, aux_label, score]`
|
||||
disambig_token:
|
||||
It is the token ID of the symbol `#0`.
|
||||
disambig_word:
|
||||
It is the word ID of the symbol `#0`.
|
||||
|
||||
Return:
|
||||
Return new `arcs` containing self-loops.
|
||||
"""
|
||||
states_needs_self_loops = set()
|
||||
for arc in arcs:
|
||||
src, dst, ilabel, olabel, score = arc
|
||||
if olabel != 0:
|
||||
states_needs_self_loops.add(src)
|
||||
|
||||
ans = []
|
||||
for s in states_needs_self_loops:
|
||||
ans.append([s, s, disambig_token, disambig_word, 0])
|
||||
|
||||
return arcs + ans
|
||||
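# Illustrative example (values are made up): only states with a non-epsilon output label
# on some outgoing arc receive a self-loop.
#
#     arcs = [[0, 1, 5, 7, 0],   # output label 7 != 0 -> state 0 needs a self-loop
#             [1, 0, 6, 0, 0]]   # output label 0      -> state 1 does not
#     add_self_loops(arcs, disambig_token=99, disambig_word=3)
#     # -> the original arcs plus [0, 0, 99, 3, 0]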
|
||||
|
||||
def lexicon_to_fst(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
sil_token: str = "SIL",
|
||||
sil_prob: float = 0.5,
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format) with optional silence at
|
||||
the beginning and end of each word.
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
sil_token:
|
||||
The silence token.
|
||||
sil_prob:
|
||||
The probability for adding a silence at the beginning and end
|
||||
of the word.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
assert sil_prob > 0.0 and sil_prob < 1.0
|
||||
# CAUTION: we use score, i.e, negative cost.
|
||||
sil_score = math.log(sil_prob)
|
||||
no_sil_score = math.log(1.0 - sil_prob)
|
||||
|
||||
start_state = 0
|
||||
loop_state = 1 # words enter and leave from here
|
||||
sil_state = 2 # words terminate here when followed by silence; this state
|
||||
# has a silence transition to loop_state.
|
||||
next_state = 3 # the next un-allocated state, will be incremented as we go.
|
||||
arcs = []
|
||||
|
||||
assert token2id["<eps>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
sil_token = token2id[sil_token]
|
||||
|
||||
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
|
||||
arcs.append([start_state, sil_state, eps, eps, sil_score])
|
||||
arcs.append([sil_state, loop_state, sil_token, eps, 0])
|
||||
|
||||
for word, tokens in lexicon:
|
||||
assert len(tokens) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
tokens = [token2id[i] for i in tokens]
|
||||
|
||||
for i in range(len(tokens) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, tokens[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last token of this word
|
||||
# It has two out-going arcs, one to the loop state,
|
||||
# the other one to the sil_state.
|
||||
i = len(tokens) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
|
||||
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
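# Worked number (follows from the defaults used in main() below): with sil_prob = 0.5,
# sil_score = log(0.5) ≈ -0.693 and no_sil_score = log(1 - 0.5) ≈ -0.693, so the FST
# scores the silence and no-silence paths around each word equally.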
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
out_dir = Path(get_args().lang_dir)
|
||||
lexicon_filename = out_dir / "lexicon.txt"
|
||||
sil_token = "SIL"
|
||||
sil_prob = 0.5
|
||||
|
||||
lexicon = read_lexicon(lexicon_filename)
|
||||
tokens = get_tokens(lexicon)
|
||||
words = get_words(lexicon)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in tokens
|
||||
tokens.append(f"#{i}")
|
||||
|
||||
assert "<eps>" not in tokens
|
||||
tokens = ["<eps>"] + tokens
|
||||
|
||||
assert "<eps>" not in words
|
||||
assert "#0" not in words
|
||||
assert "<s>" not in words
|
||||
assert "</s>" not in words
|
||||
|
||||
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
|
||||
|
||||
token2id = generate_id_map(tokens)
|
||||
word2id = generate_id_map(words)
|
||||
|
||||
write_mapping(out_dir / "tokens.txt", token2id)
|
||||
write_mapping(out_dir / "words.txt", word2id)
|
||||
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst(
|
||||
lexicon,
|
||||
token2id=token2id,
|
||||
word2id=word2id,
|
||||
sil_token=sil_token,
|
||||
sil_prob=sil_prob,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst(
|
||||
lexicon_disambig,
|
||||
token2id=token2id,
|
||||
word2id=word2id,
|
||||
sil_token=sil_token,
|
||||
sil_prob=sil_prob,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), out_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
|
||||
|
||||
if False:
|
||||
# Just for debugging, will remove it
|
||||
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
|
||||
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
|
||||
L_disambig.labels_sym = L.labels_sym
|
||||
L_disambig.aux_labels_sym = L.aux_labels_sym
|
||||
L.draw(out_dir / "L.png", title="L")
|
||||
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
107
egs/librispeech/SSL/local/process_librispeech4finetune.py
Normal file
@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help="""Dataset parts to compute fbank. If None, we will use all""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def process_wav_librispeech(
|
||||
dataset: Optional[str] = None,
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/wav")
|
||||
|
||||
if dataset is None:
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"dev-other",
|
||||
"test-clean",
|
||||
"test-other",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
)
|
||||
else:
|
||||
dataset_parts = dataset.split(" ", -1)
|
||||
|
||||
prefix = "librispeech"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
process_wav_librispeech(
|
||||
dataset=args.dataset,
|
||||
)
|
104
egs/librispeech/SSL/local/process_librispeech4pretrain.py
Normal file
@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help="""Dataset parts to compute fbank. If None, we will use all""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def process_kmeans_librispeech(
|
||||
dataset: Optional[str] = None,
|
||||
):
|
||||
src_dir = Path(".")
|
||||
output_dir = Path(".")
|
||||
|
||||
if dataset is None:
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
)
|
||||
else:
|
||||
dataset_parts = dataset.split(" ", -1)
|
||||
|
||||
prefix = "librispeech"
|
||||
suffix = "jsonl"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
cuts_filename = f"{prefix}_cuts_{partition}_raw.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
process_kmeans_librispeech(
|
||||
dataset=args.dataset,
|
||||
)
|
23
egs/librispeech/SSL/local/process_raw_cuts.py
Normal file
@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import jsonlines
|
||||
from tqdm import tqdm
|
||||
|
||||
dataset_parts = (
|
||||
"dev-clean",
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
)
|
||||
|
||||
for part in dataset_parts:
|
||||
with jsonlines.open(f"librispeech_cuts_{part}_raw.jsonl") as reader:
|
||||
with jsonlines.open(f"librispeech_cuts_{part}.jsonl", mode="w") as writer:
|
||||
for obj in tqdm(reader):
|
||||
obj["custom"] = {"kmeans": obj["supervisions"][0]["custom"]["kmeans"]}
|
||||
del obj["supervisions"][0]["custom"]
|
||||
|
||||
writer.write(obj)
|
||||
|
||||
os.system("rm *_raw.jsonl")
|
||||
os.system("gzip *.jsonl")
|
1
egs/librispeech/SSL/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared
|
1
egs/librispeech/SSL/zipformer/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../hubert/asr_datamodule.py
|
1
egs/librispeech/SSL/zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/beam_search.py
|
1
egs/librispeech/SSL/zipformer/dataset.py
Symbolic link
@ -0,0 +1 @@
|
||||
../hubert/dataset.py
|
1043
egs/librispeech/SSL/zipformer/decode.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/decoder.py
|
1
egs/librispeech/SSL/zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/encoder_interface.py
|
1551
egs/librispeech/SSL/zipformer/finetune.py
Normal file
File diff suppressed because it is too large
Load Diff
601
egs/librispeech/SSL/zipformer/hubert_ce.py
Normal file
@ -0,0 +1,601 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scaling import ScheduledFloat
|
||||
from utils import GradMultiply, LayerNorm
|
||||
from wav2vec2_module import ConvFeatureExtractionModel
|
||||
from zipformer import Zipformer2
|
||||
|
||||
|
||||
def compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
mask_type: str = "static",
|
||||
mask_other: float = 0.0,
|
||||
min_masks: int = 0,
|
||||
no_overlap: bool = False,
|
||||
min_space: int = 0,
|
||||
require_same_masks: bool = True,
|
||||
mask_dropout: float = 0.0,
|
||||
add_masks: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
epoch: Optional[int] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
|
||||
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
should be of size 2, where the first element is the batch size and the second is the number of timesteps
|
||||
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_type: how to compute mask lengths
|
||||
static = fixed size
|
||||
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||
normal = sample from a normal distribution with mean mask_length and stdev mask_other; each mask is at least 1 element long
|
||||
poisson = sample from a Poisson distribution with lambda = mask_length
|
||||
min_masks: minimum number of masked spans
|
||||
no_overlap: if true, use an alternative recursive algorithm that prevents spans from overlapping
|
||||
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
|
||||
mask_dropout: randomly dropout this percentage of masks in each example
|
||||
"""
|
||||
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
if num_mask_ver == 1:
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
for i in range(bsz):
|
||||
if seed is not None and epoch is not None and indices is not None:
|
||||
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
||||
else:
|
||||
seed_i = None
|
||||
|
||||
rng = np.random.default_rng(seed_i)
|
||||
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
assert sz >= 0, sz
|
||||
else:
|
||||
sz = all_sz
|
||||
|
||||
if num_mask_ver == 1:
|
||||
if padding_mask is not None:
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
num_mask = all_num_mask
|
||||
elif num_mask_ver == 2:
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ rng.random()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
if mask_type == "static":
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
elif mask_type == "uniform":
|
||||
lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
|
||||
elif mask_type == "normal":
|
||||
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
||||
lengths = [max(1, int(round(x))) for x in lengths]
|
||||
elif mask_type == "poisson":
|
||||
lengths = rng.poisson(mask_length, size=num_mask)
|
||||
lengths = [int(round(x)) for x in lengths]
|
||||
else:
|
||||
raise Exception("unknown mask selection " + mask_type)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
if mask_type == "static":
|
||||
raise ValueError(f"this should never happens")
|
||||
else:
|
||||
lengths = [min(mask_length, sz - 1)]
|
||||
|
||||
if no_overlap:
|
||||
mask_idc = []
|
||||
|
||||
def arrange(s, e, length, keep_length):
|
||||
span_start = rng.integers(s, e - length)
|
||||
mask_idc.extend(span_start + i for i in range(length))
|
||||
|
||||
new_parts = []
|
||||
if span_start - s - min_space >= keep_length:
|
||||
new_parts.append((s, span_start - min_space + 1))
|
||||
if e - span_start - length - min_space > keep_length:
|
||||
new_parts.append((span_start + length + min_space, e))
|
||||
return new_parts
|
||||
|
||||
parts = [(0, sz)]
|
||||
min_length = min(lengths)
|
||||
for length in sorted(lengths, reverse=True):
|
||||
lens = np.fromiter(
|
||||
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||
int,
|
||||
)
|
||||
l_sum = np.sum(lens)
|
||||
if l_sum == 0:
|
||||
break
|
||||
probs = lens / np.sum(lens)
|
||||
c = rng.choice(len(parts), p=probs)
|
||||
s, e = parts.pop(c)
|
||||
parts.extend(arrange(s, e, length, min_length))
|
||||
mask_idc = np.asarray(mask_idc)
|
||||
else:
|
||||
if idc_select_ver == 1:
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
||||
elif idc_select_ver == 2:
|
||||
mask_idc = rng.choice(sz, num_mask, replace=False)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
mask_idc = np.asarray(
|
||||
[
|
||||
mask_idc[j] + offset
|
||||
for j in range(len(mask_idc))
|
||||
for offset in range(lengths[j])
|
||||
]
|
||||
)
|
||||
|
||||
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
||||
if len(mask_idc) >= sz:
|
||||
raise ValueError(
|
||||
(
|
||||
f"the entire sequence is masked. "
|
||||
f"sz={sz}; mask_idc[mask_idc]; "
|
||||
f"index={indices[i] if indices is not None else None}"
|
||||
)
|
||||
)
|
||||
mask_idcs.append(mask_idc)
|
||||
|
||||
target_len = None
|
||||
if require_same_masks:
|
||||
if add_masks:
|
||||
target_len = max([len(m) for m in mask_idcs])
|
||||
else:
|
||||
target_len = min([len(m) for m in mask_idcs])
|
||||
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if target_len is not None and len(mask_idc) > target_len:
|
||||
mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
||||
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
if target_len is not None and len(mask_idc) < target_len:
|
||||
unmasked = np.flatnonzero(~mask[i])
|
||||
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
|
||||
mask[i, to_mask] = True
|
||||
|
||||
if mask_dropout > 0:
|
||||
masked = np.flatnonzero(mask[i])
|
||||
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
|
||||
to_drop = rng.choice(masked, num_holes, replace=False)
|
||||
mask[i, to_drop] = False
|
||||
|
||||
return mask
|
||||
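# A small illustrative call (assumed values, not from this recipe): for a batch of 2
# utterances with 100 frames each, mask roughly 65% of the frames in spans of length 10,
# as HuBERT-style pre-training typically does.
#
#     mask = compute_mask_indices(
#         shape=(2, 100),
#         padding_mask=None,
#         mask_prob=0.65,
#         mask_length=10,
#         mask_type="static",
#         min_masks=2,
#     )
#     # mask is a boolean np.ndarray of shape (2, 100); mask.sum(axis=1) gives the
#     # number of masked frames per utterance (equal across the batch by default,
#     # because require_same_masks=True)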
|
||||
|
||||
def _to_int_tuple(s: str):
|
||||
return tuple(map(int, s.split(",")))
|
||||
|
||||
|
||||
class HubertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
|
||||
self.embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=cfg.extractor_mode,
|
||||
conv_bias=cfg.conv_bias,
|
||||
)
|
||||
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
||||
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
|
||||
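# Worked example (assumed typical values, not taken from this file): the default
# wav2vec2-style conv stack downsamples by 320x, so with 16 kHz audio and 50 Hz
# k-means labels, feat2tar_ratio = 50 * 320 / 16000 = 1.0, i.e. one label per frame.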
encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
|
||||
encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
|
||||
self.post_extract_proj = (
|
||||
nn.Linear(self.embed, encoder_input_dim)
|
||||
if self.embed != encoder_input_dim
|
||||
else None
|
||||
)
|
||||
|
||||
self.mask_prob = cfg.mask_prob
|
||||
self.mask_selection = cfg.mask_selection
|
||||
self.mask_other = cfg.mask_other
|
||||
self.mask_length = cfg.mask_length
|
||||
self.no_mask_overlap = cfg.no_mask_overlap
|
||||
self.mask_min_space = cfg.mask_min_space
|
||||
|
||||
self.mask_channel_prob = cfg.mask_channel_prob
|
||||
self.mask_channel_selection = cfg.mask_channel_selection
|
||||
self.mask_channel_other = cfg.mask_channel_other
|
||||
self.mask_channel_length = cfg.mask_channel_length
|
||||
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||
|
||||
self.feature_grad_mult = cfg.feature_grad_mult
|
||||
self.logit_temp = cfg.logit_temp
|
||||
self.skip_masked = cfg.skip_masked
|
||||
self.skip_nomask = cfg.skip_nomask
|
||||
|
||||
self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())
|
||||
|
||||
self.encoder = Zipformer2(
|
||||
output_downsampling_factor=1,
|
||||
downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
|
||||
num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
|
||||
encoder_dim=_to_int_tuple(cfg.encoder_dim),
|
||||
encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
|
||||
query_head_dim=_to_int_tuple(cfg.query_head_dim),
|
||||
pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
|
||||
value_head_dim=_to_int_tuple(cfg.value_head_dim),
|
||||
pos_dim=cfg.pos_dim,
|
||||
num_heads=_to_int_tuple(cfg.num_heads),
|
||||
feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
|
||||
cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
)
|
||||
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
self.untie_final_proj = cfg.untie_final_proj
|
||||
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
|
||||
|
||||
# modules below are not needed during fine-tuning
|
||||
self.num_classes = cfg.num_classes
|
||||
self.pred_masked_weight = cfg.pred_masked_weight
|
||||
self.pred_nomask_weight = cfg.pred_nomask_weight
|
||||
self.loss_weights = cfg.loss_weights
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
||||
|
||||
# nn.Module has no upgrade_state_dict_named() hook, so the dict is returned unchanged.
|
||||
return state_dict
|
||||
|
||||
def apply_mask(self, x, padding_mask, target_list):
|
||||
B, T, C = x.shape
|
||||
if self.mask_prob > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=2,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space,
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||
x[mask_indices] = self.mask_emb.to(x.dtype)
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = (
|
||||
torch.from_numpy(mask_channel_indices)
|
||||
.to(x.device)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, T, -1)
|
||||
)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
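# Shape summary for apply_mask (descriptive only): mask_indices is a (B, T) bool
# tensor; x[mask_indices] selects the masked frames, which are overwritten with
# the learned self.mask_emb.  Channel masking builds a (B, C) mask, expands it
# to (B, T, C) and zeroes the chosen feature dimensions across all frames.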
|
||||
|
||||
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(source)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(source)
|
||||
return features
|
||||
|
||||
def forward_targets(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
target_list: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Trim features to ensure labels exist and then get aligned labels
|
||||
feat_tsz = features.size(2)
|
||||
targ_tsz = min([t.size(1) for t in target_list])
|
||||
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
||||
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
||||
features = features[..., :feat_tsz]
|
||||
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
|
||||
target_list = [t[:, target_inds.long()] for t in target_list]
|
||||
return features, target_list
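# Alignment sketch (illustrative): with feat2tar_ratio == 1.0, 500 feature frames
# and a shortest label stream of 498 frames, the features are trimmed to 498 and
# target_inds = [0, 1, ..., 497], keeping labels and frames one-to-one.  With
# feat2tar_ratio == 2.0 every second label would be selected instead.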
|
||||
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
extra = padding_mask.size(1) % features.size(1)
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
||||
padding_mask = padding_mask.all(-1)
|
||||
return padding_mask
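# Example (illustrative): for a 16000-sample padding mask and 50 feature frames,
# extra = 16000 % 50 = 0, the mask is viewed as (B, 50, 320) and a frame counts
# as padding only if all 320 samples it covers are padding (the .all(-1) above).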
|
||||
|
||||
def forward(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
target_list: Optional[List[torch.Tensor]] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = True,
|
||||
features_only: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
):
|
||||
"""output layer is 1-based"""
|
||||
features = self.forward_features(source)
|
||||
if target_list is not None:
|
||||
features, target_list = self.forward_targets(features, target_list)
|
||||
|
||||
features_pen = features.float().pow(2).mean()
|
||||
|
||||
features = features.transpose(1, 2)
|
||||
features = self.layer_norm(features)
|
||||
unmasked_features = features.clone()
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
features = self.dropout_input(features)
|
||||
unmasked_features = self.dropout_features(unmasked_features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
||||
else:
|
||||
x = features
|
||||
mask_indices = None
|
||||
|
||||
# feature: (B, T, D), float
|
||||
# target: (B, T), long
|
||||
# x: (B, T, D), float -> (T, B, D), float
|
||||
# padding_mask: (B, T), bool
|
||||
# mask_indices: (B, T), bool
|
||||
x = x.transpose(0, 1)
|
||||
x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))  # lengths = number of non-padded frames
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
if features_only:
|
||||
return {"x": x, "padding_mask": padding_mask, "features": features}
|
||||
|
||||
if not self.skip_masked:
|
||||
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
||||
proj_x_m = self.final_proj(x[masked_indices])
|
||||
proj_x_m /= self.logit_temp
|
||||
logit_m_list = [proj_x_m for _ in range(len(target_list))]
|
||||
else:
|
||||
logit_m_list = [None for _ in target_list]
|
||||
|
||||
if not self.skip_nomask:
|
||||
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
||||
proj_x_u = self.final_proj(x[nomask_indices])
|
||||
proj_x_u /= self.logit_temp
|
||||
logit_u_list = [proj_x_u for _ in range(len(target_list))]
|
||||
else:
|
||||
logit_u_list = [None for _ in target_list]
|
||||
|
||||
# result = {
|
||||
# "logit_m_list": logit_m_list,
|
||||
# "logit_u_list": logit_u_list,
|
||||
# "padding_mask": padding_mask,
|
||||
# "features_pen": features_pen,
|
||||
# }
|
||||
targ_m_list = target_list[0][masked_indices]
|
||||
targ_m_list = targ_m_list.long()
|
||||
targ_m_list = [targ_m_list for _ in range(len(target_list))]
|
||||
|
||||
targ_u_list = target_list[0][nomask_indices]
|
||||
targ_u_list = targ_u_list.long()
|
||||
targ_u_list = [targ_u_list for _ in range(len(target_list))]
|
||||
return self.compute_loss(
|
||||
logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
|
||||
)
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = False,
|
||||
ret_conv: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
res = self.forward(
|
||||
source,
|
||||
padding_mask=padding_mask,
|
||||
mask=mask,
|
||||
features_only=True,
|
||||
output_layer=output_layer,
|
||||
)
|
||||
feature = res["features"] if ret_conv else res["x"]
|
||||
return feature, res["padding_mask"]
|
||||
|
||||
def get_logits(self, net_output, is_masked=True):
|
||||
if is_masked:
|
||||
logits_list = net_output["logit_m_list"]
|
||||
else:
|
||||
logits_list = net_output["logit_u_list"]
|
||||
logits_list = [x.float() for x in logits_list if x is not None]
|
||||
return logits_list
|
||||
|
||||
def get_targets(self, net_output, is_masked=True):
|
||||
logits_list = self.get_logits(net_output, is_masked)
|
||||
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
|
||||
return targets_list
|
||||
|
||||
def get_extra_losses(self, net_output):
|
||||
extra_losses = []
|
||||
names = []
|
||||
|
||||
if "features_pen" in net_output:
|
||||
extra_losses.append(net_output["features_pen"])
|
||||
names.append("features_pen")
|
||||
|
||||
return extra_losses, names
|
||||
|
||||
def remove_pretraining_modules(self):
|
||||
self.final_proj = None
|
||||
|
||||
def compute_loss(
|
||||
self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
|
||||
):
|
||||
loss = 0.0
|
||||
sample_size = 0
|
||||
logging_output = {}
|
||||
reduce = True
|
||||
reduction = "sum" if reduce else "none"
|
||||
|
||||
loss_m_list = []
|
||||
logp_m_list = [x.float() for x in logit_m_list if x is not None]
|
||||
logp_m_list = torch.cat(logp_m_list)
|
||||
targ_m_list = torch.cat(targ_m_list)
|
||||
|
||||
loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
|
||||
loss_m_list.append(loss_m)
|
||||
logging_output["loss_m_0"] = loss_m.detach().item()
|
||||
|
||||
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
|
||||
if self.pred_masked_weight > 0:
|
||||
loss += self.pred_masked_weight * sum(loss_m_list)
|
||||
sample_size += len(targ_m_list)
|
||||
|
||||
loss_u_list = []
|
||||
logp_u_list = [x.float() for x in logit_u_list if x is not None]
|
||||
logp_u_list = torch.cat(logp_u_list)
|
||||
targ_u_list = torch.cat(targ_u_list)
|
||||
|
||||
loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
|
||||
loss_u_list.append(loss_u)
|
||||
logging_output["loss_u_0"] = loss_u.detach().item()
|
||||
|
||||
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
|
||||
if self.pred_nomask_weight > 0:
|
||||
loss += self.pred_nomask_weight * sum(loss_u_list)
|
||||
sample_size += len(targ_u_list)
|
||||
|
||||
if self.loss_weights is not None:
|
||||
extra_losses = []
|
||||
names = []
|
||||
extra_losses.append(features_pen)
|
||||
names.append("features_pen")
|
||||
if torch.is_tensor(extra_losses):
|
||||
extra_losses = [extra_losses]
|
||||
names = [names]
|
||||
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
||||
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
||||
assert len(extra_losses) == len(
|
||||
self.loss_weights
|
||||
), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
||||
for p, n, coef in zip(extra_losses, names, self.loss_weights):
|
||||
if coef != 0 and p is not None:
|
||||
p = coef * p.float() * sample_size
|
||||
loss += p
|
||||
logging_output[f"loss_{n}"] = p.item()
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.item() if reduce else loss,
|
||||
**logging_output,
|
||||
}
|
||||
|
||||
# for lk in self.log_keys:
|
||||
# if lk in net_output:
|
||||
# logging_output[lk] = float((net_output[lk]))
|
||||
|
||||
def compute_correct(logits, target):
|
||||
if logits.numel() == 0:
|
||||
return 0, 0
|
||||
else:
|
||||
assert logits.dim() > 1, logits.shape
|
||||
max_hit = logits.argmax(-1) == target  # renamed to avoid shadowing the max/min builtins
|
||||
min_hit = logits.argmin(-1) == target
|
||||
both = max_hit & min_hit
|
||||
corr = max_hit.long().sum().item() - both.long().sum().item()
|
||||
count = max_hit.numel()
|
||||
return corr, count
|
||||
|
||||
with torch.no_grad():
|
||||
corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
|
||||
logging_output["correct_m_0"] = corr_m
|
||||
logging_output["count_m_0"] = count_m
|
||||
|
||||
corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
|
||||
logging_output["correct_u_0"] = corr_u
|
||||
logging_output["count_u_0"] = count_u
|
||||
|
||||
return loss, sample_size, logging_output
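# Summary of the loss assembled above (descriptive, no new behaviour):
#   loss = pred_masked_weight * CE(masked logits, masked targets)
#        + pred_nomask_weight * CE(unmasked logits, unmasked targets)
#        + loss_weights[0] * features_pen * sample_size
# Both cross-entropies use reduction="sum"; sample_size counts the target frames
# behind the terms whose weights are non-zero.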
|
1
egs/librispeech/SSL/zipformer/joiner.py
Symbolic link
1
egs/librispeech/SSL/zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/joiner.py
|
344
egs/librispeech/SSL/zipformer/model.py
Normal file
344
egs/librispeech/SSL/zipformer/model.py
Normal file
@ -0,0 +1,344 @@
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Zengwei Yao,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 768,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
use_transducer: bool = True,
|
||||
use_ctc: bool = False,
|
||||
):
|
||||
"""A joint CTC & Transducer ASR model.
|
||||
|
||||
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
|
||||
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
|
||||
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
|
||||
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
inputs: `x` of (N, T, encoder_dim).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
It is used when use_transducer is True.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
It is used when use_transducer is True.
|
||||
use_transducer:
|
||||
Whether to use the transducer head. Default: True.
|
||||
use_ctc:
|
||||
Whether to use the CTC head. Default: False.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
use_transducer or use_ctc
|
||||
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||
|
||||
self.encoder = encoder
|
||||
|
||||
self.use_transducer = use_transducer
|
||||
if use_transducer:
|
||||
# Modules for Transducer head
|
||||
assert decoder is not None
|
||||
assert hasattr(decoder, "blank_id")
|
||||
assert joiner is not None
|
||||
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
else:
|
||||
assert decoder is None
|
||||
assert joiner is None
|
||||
|
||||
self.use_ctc = use_ctc
|
||||
if use_ctc:
|
||||
# Modules for CTC head
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
def forward_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute encoder outputs.
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (N, T).
|
||||
|
||||
Returns:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
"""
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.zeros_like(x, dtype=torch.bool)
|
||||
|
||||
encoder_out, padding_mask = self.encoder.extract_features(
|
||||
source=x,
|
||||
padding_mask=padding_mask,
|
||||
mask=self.encoder.training,
|
||||
)
|
||||
encoder_out_lens = torch.sum(~padding_mask, dim=1)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def forward_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
targets:
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets,
|
||||
input_lengths=encoder_out_lens,
|
||||
target_lengths=target_lengths,
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
||||
def forward_transducer(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
y_lens: torch.Tensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute Transducer loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss; it means how many symbols (context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
"""
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = encoder_out_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return simple_loss, pruned_loss
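# Flow recap (descriptive): k2.rnnt_loss_smoothed runs a cheap "simple" joiner
# over the full (T x U) lattice and returns gradients that
# k2.get_rnnt_prune_ranges turns into per-frame symbol ranges; the full joiner
# is then evaluated only on that pruned region and scored with
# k2.rnnt_loss_pruned, which is what makes pruned RNN-T training memory-efficient.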
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (N, T).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss; it means how many symbols (context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer losses and CTC loss,
|
||||
in the form of (simple_loss, pruned_loss, ctc_loss, encoder_out_lens)
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss function of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 2, x.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == y.dim0, (x.shape, y.dim0)
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
|
||||
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
if self.use_transducer:
|
||||
# Compute transducer loss
|
||||
simple_loss, pruned_loss = self.forward_transducer(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
y=y.to(x.device),
|
||||
y_lens=y_lens,
|
||||
prune_range=prune_range,
|
||||
am_scale=am_scale,
|
||||
lm_scale=lm_scale,
|
||||
)
|
||||
else:
|
||||
simple_loss = torch.empty(0)
|
||||
pruned_loss = torch.empty(0)
|
||||
|
||||
if self.use_ctc:
|
||||
# Compute CTC loss
|
||||
targets = y.values
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
else:
|
||||
ctc_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
|
1
egs/librispeech/SSL/zipformer/optim.py
Symbolic link
1
egs/librispeech/SSL/zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/optim.py
|
1380
egs/librispeech/SSL/zipformer/pretrain.py
Normal file
1380
egs/librispeech/SSL/zipformer/pretrain.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/zipformer/scaling.py
Symbolic link
1
egs/librispeech/SSL/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/scaling.py
|
1
egs/librispeech/SSL/zipformer/ssl_datamodule.py
Symbolic link
1
egs/librispeech/SSL/zipformer/ssl_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../hubert/ssl_datamodule.py
|
337
egs/librispeech/SSL/zipformer/utils.py
Normal file
337
egs/librispeech/SSL/zipformer/utils.py
Normal file
@ -0,0 +1,337 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def relu_squared(x: torch.Tensor):
|
||||
return F.relu(x).pow(2)
|
||||
|
||||
|
||||
def gelu_accurate(x):
|
||||
if not hasattr(gelu_accurate, "_a"):
|
||||
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||
return (
|
||||
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
|
||||
|
||||
def is_xla_tensor(tensor):
|
||||
return torch.is_tensor(tensor) and tensor.device.type == "xla"
|
||||
|
||||
|
||||
def index_put(tensor, indices, value):
|
||||
if is_xla_tensor(tensor):
|
||||
for _ in range(indices.dim(), tensor.dim()):
|
||||
indices = indices.unsqueeze(-1)
|
||||
if indices.size(-1) < tensor.size(-1):
|
||||
indices = indices.expand_as(tensor)
|
||||
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
|
||||
else:
|
||||
tensor[indices] = value
|
||||
return tensor
|
||||
|
||||
|
||||
def pad_to_multiple(x, multiple, dim=-1, value=0):
|
||||
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
|
||||
if x is None:
|
||||
return None, 0
|
||||
tsz = x.size(dim)
|
||||
m = tsz / multiple
|
||||
remainder = math.ceil(m) * multiple - tsz
|
||||
if m.is_integer():
|
||||
return x, 0
|
||||
pad_offset = (0,) * (-1 - dim) * 2
|
||||
|
||||
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
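# Example (illustrative): if x.size(dim) == 10 and multiple == 4, then
# remainder == 2 and the padded size becomes 12; when the size is already a
# multiple of `multiple` the input is returned unchanged with remainder 0.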
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str) -> Callable:
|
||||
"""Returns the activation function corresponding to `activation`"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "relu_squared":
|
||||
return relu_squared
|
||||
elif activation == "gelu":
|
||||
return gelu
|
||||
elif activation == "gelu_fast":
|
||||
return gelu_accurate
|
||||
elif activation == "gelu_accurate":
|
||||
return gelu_accurate
|
||||
elif activation == "tanh":
|
||||
return torch.tanh
|
||||
elif activation == "linear":
|
||||
return lambda x: x
|
||||
elif activation == "swish":
|
||||
return torch.nn.SiLU()  # instantiate so the returned object maps tensor -> tensor like the others
|
||||
else:
|
||||
raise RuntimeError("--activation-fn {} not supported".format(activation))
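# Usage sketch (illustrative):
#     act = get_activation_fn("gelu")
#     y = act(torch.randn(2, 8))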
|
||||
|
||||
|
||||
class SamePad(nn.Module):
|
||||
def __init__(self, kernel_size, causal=False):
|
||||
super().__init__()
|
||||
if causal:
|
||||
self.remove = kernel_size - 1
|
||||
else:
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove]
|
||||
return x
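# Example (illustrative): a non-causal Conv1d with kernel_size=4 and padding=2
# yields T + 1 output frames; SamePad(4) drops the trailing frame so the output
# length matches the input.  With causal=True the kernel_size - 1 trailing
# (future-looking) frames are removed instead.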
|
||||
|
||||
|
||||
class SamePad2d(nn.Module):
|
||||
def __init__(self, kernel_size):
|
||||
super().__init__()
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 4
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class TransposeLast(nn.Module):
|
||||
def __init__(self, deconstruct_idx=None, transpose_dim=-2):
|
||||
super().__init__()
|
||||
self.deconstruct_idx = deconstruct_idx
|
||||
self.transpose_dim = transpose_dim
|
||||
|
||||
def forward(self, x):
|
||||
if self.deconstruct_idx is not None:
|
||||
x = x[self.deconstruct_idx]
|
||||
return x.transpose(self.transpose_dim, -1)
|
||||
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm as _FusedLayerNorm
|
||||
|
||||
has_fused_layernorm = True
|
||||
|
||||
class FusedLayerNorm(_FusedLayerNorm):
|
||||
@torch.jit.unused
|
||||
def forward(self, x):
|
||||
if not x.is_cuda:
|
||||
return super().forward(x)
|
||||
else:
|
||||
with torch.cuda.device(x.device):
|
||||
return super().forward(x)
|
||||
|
||||
except ImportError:
|
||||
has_fused_layernorm = False
|
||||
|
||||
|
||||
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
export = True
|
||||
if not export and torch.cuda.is_available() and has_fused_layernorm:
|
||||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
|
||||
class Fp32LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.layer_norm(
|
||||
input.float(),
|
||||
self.normalized_shape,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
class Fp32GroupNorm(nn.GroupNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.group_norm(
|
||||
input.float(),
|
||||
self.num_groups,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
def softmax(x, dim: int, onnx_trace: bool = False):
|
||||
if onnx_trace:
|
||||
return F.softmax(x.float(), dim=dim)
|
||||
else:
|
||||
return F.softmax(x, dim=dim, dtype=torch.float32)
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
|
||||
"""
|
||||
Wraps modules and applies quantization noise to the weights for
|
||||
subsequent quantization with Iterative Product Quantization as
|
||||
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||
|
||||
Args:
|
||||
- module: nn.Module
|
||||
- p: amount of Quantization Noise
|
||||
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||
|
||||
Remarks:
|
||||
- Module weights must have the right sizes wrt the block size
|
||||
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||
- For more detail on how to quantize by blocks with convolutional weights,
|
||||
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||
- We implement the simplest form of noise here as stated in the paper
|
||||
which consists in randomly dropping blocks
|
||||
"""
|
||||
|
||||
# if no quantization noise, don't register hook
|
||||
if p <= 0:
|
||||
return module
|
||||
|
||||
# supported modules
|
||||
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||
|
||||
# test whether module.weight has the right sizes wrt block_size
|
||||
is_conv = module.weight.ndim == 4
|
||||
|
||||
# 2D matrix
|
||||
if not is_conv:
|
||||
assert (
|
||||
module.weight.size(1) % block_size == 0
|
||||
), "Input features must be a multiple of block sizes"
|
||||
|
||||
# 4D matrix
|
||||
else:
|
||||
# 1x1 convolutions
|
||||
if module.kernel_size == (1, 1):
|
||||
assert (
|
||||
module.in_channels % block_size == 0
|
||||
), "Input channels must be a multiple of block sizes"
|
||||
# regular convolutions
|
||||
else:
|
||||
k = module.kernel_size[0] * module.kernel_size[1]
|
||||
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||
|
||||
def _forward_pre_hook(mod, input):
|
||||
# no noise for evaluation
|
||||
if mod.training:
|
||||
if not is_conv:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_features = weight.size(1)
|
||||
out_features = weight.size(0)
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
mask = torch.zeros(
|
||||
in_features // block_size * out_features, device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||
|
||||
else:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_channels = mod.in_channels
|
||||
out_channels = mod.out_channels
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
if mod.kernel_size == (1, 1):
|
||||
mask = torch.zeros(
|
||||
int(in_channels // block_size * out_channels),
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||
else:
|
||||
mask = torch.zeros(
|
||||
weight.size(0), weight.size(1), device=weight.device
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = (
|
||||
mask.unsqueeze(2)
|
||||
.unsqueeze(3)
|
||||
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||
)
|
||||
|
||||
# scale weights and apply mask
|
||||
mask = mask.to(
|
||||
torch.bool
|
||||
) # x.bool() is not currently supported in TorchScript
|
||||
s = 1 / (1 - p)
|
||||
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||
|
||||
module.register_forward_pre_hook(_forward_pre_hook)
|
||||
return module
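# Usage sketch (illustrative, hypothetical sizes):
#     layer = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)
# During training the pre-forward hook zeroes roughly 10% of the 8-wide weight
# blocks and rescales the survivors by 1 / (1 - p); in eval mode the hook does
# nothing, so the weights are used as-is.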
|
||||
|
||||
|
||||
class FairseqDropout(nn.Module):
|
||||
def __init__(self, p, module_name=None):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
self.module_name = module_name
|
||||
self.apply_during_inference = False
|
||||
|
||||
def forward(self, x, inplace: bool = False):
|
||||
if self.p > 0 and (self.training or self.apply_during_inference):
|
||||
return F.dropout(x, p=self.p, training=True, inplace=inplace)
|
||||
else:
|
||||
return x
|
||||
|
||||
def make_generation_fast_(
|
||||
self,
|
||||
name: str,
|
||||
retain_dropout: bool = False,
|
||||
retain_dropout_modules: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
):
|
||||
if retain_dropout:
|
||||
if retain_dropout_modules is not None and self.module_name is None:
|
||||
pass
|
||||
elif (
|
||||
retain_dropout_modules is None # if None, apply to all modules
|
||||
or self.module_name in retain_dropout_modules
|
||||
):
|
||||
self.apply_during_inference = True
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
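# Usage sketch (illustrative): GradMultiply.apply(features, 0.1) is the identity
# in the forward pass but scales the incoming gradient by 0.1 on the way back;
# HubertModel.forward_features uses it to damp gradients reaching the
# convolutional feature extractor when 0 < feature_grad_mult < 1.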
|
108
egs/librispeech/SSL/zipformer/wav2vec2_module.py
Normal file
108
egs/librispeech/SSL/zipformer/wav2vec2_module.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
|
||||
|
||||
|
||||
class ConvFeatureExtractionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conv_layers: List[Tuple[int, int, int]],
|
||||
dropout: float = 0.0,
|
||||
mode: str = "default",
|
||||
conv_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode in {"default", "layer_norm"}
|
||||
|
||||
def block(
|
||||
n_in,
|
||||
n_out,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=False,
|
||||
is_group_norm=False,
|
||||
conv_bias=False,
|
||||
):
|
||||
def make_conv():
|
||||
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
||||
nn.init.kaiming_normal_(conv.weight)
|
||||
return conv
|
||||
|
||||
assert (
|
||||
is_layer_norm and is_group_norm
|
||||
) == False, "layer norm and group norm are exclusive"
|
||||
|
||||
if is_layer_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Sequential(
|
||||
TransposeLast(),
|
||||
Fp32LayerNorm(dim, elementwise_affine=True),
|
||||
TransposeLast(),
|
||||
),
|
||||
nn.GELU(),
|
||||
)
|
||||
elif is_group_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
Fp32GroupNorm(dim, dim, affine=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
||||
|
||||
in_d = 1
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
||||
(dim, k, stride) = cl
|
||||
|
||||
self.conv_layers.append(
|
||||
block(
|
||||
in_d,
|
||||
dim,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=mode == "layer_norm",
|
||||
is_group_norm=mode == "default" and i == 0,
|
||||
conv_bias=conv_bias,
|
||||
)
|
||||
)
|
||||
in_d = dim
|
||||
|
||||
def forward(self, x):
|
||||
# BxT -> BxCxT
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x
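# Usage sketch (illustrative, assuming the common wav2vec 2.0 / HuBERT layout):
#     layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
#     extractor = ConvFeatureExtractionModel(conv_layers=layers)
#     feats = extractor(torch.randn(2, 16000))   # -> (2, 512, 49)
# The strides multiply to 320, i.e. roughly one feature frame per 20 ms at 16 kHz.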
|
2438
egs/librispeech/SSL/zipformer/zipformer.py
Normal file
2438
egs/librispeech/SSL/zipformer/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff