k2SSL: a Faster and Better Framework for Self-Supervised Speech Representation Learning (#1500)

* Add k2SSL

* fix flake8

* fix for black

* fix for black

* fix for black

* Update ssl_datamodule.py

* Fix bugs in HubertDataset

* update comments

* add librilight

* add checkpoint convert script

* format

---------

Co-authored-by: yifanyeung <yifanyeung@yifanyeung.local>
Co-authored-by: zzasdf <15218404468@163.com>
Yifan Yang 2024-04-04 23:29:16 +08:00 committed by GitHub
parent c45e9fecfb
commit 87843e9382
62 changed files with 24874 additions and 0 deletions


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/asr_datamodule.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/beam_search.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/dataset.py

File diff suppressed because it is too large.


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/decoder.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/encoder_interface.py

File diff suppressed because it is too large.


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/hubert_ce.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/joiner.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/model.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/optim.py

File diff suppressed because it is too large.


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/scaling.py


@@ -0,0 +1,334 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from dataset import HubertDataset
from lhotse import CutSet, combine, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriLightDataModule:
"""
DataModule for SSL experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in SSL
experiments, e.g.:
- dynamic batch size,
    - bucketing samplers.
This class should be derived for specific corpora used in SSL tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR SSL related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSets -- they control the "
            "effective batch sizes and sampling strategies.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/kmeans"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=float,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--do-normalize",
type=str2bool,
default=True,
help="whether to normalize the data",
)
group.add_argument(
"--random-crop",
type=str2bool,
default=True,
            help="whether to randomly crop the audio",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(
self,
cuts_valid: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(
self,
cuts: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def small_cuts(self) -> CutSet:
logging.info("About to get small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
)
@lru_cache()
def medium_cuts(self) -> CutSet:
logging.info("About to get medium cuts")
filenames = glob.glob(
f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz"
)
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode"
)
return combine(load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def large_cuts(self) -> CutSet:
logging.info("About to get large cuts")
filenames = glob.glob(
f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz"
)
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode"
)
return combine(load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info("About to get the shuffled small, medium and large cuts")
small_cuts = self.small_cuts()
medium_cuts = self.medium_cuts()
large_cuts = self.large_cuts()
return CutSet.mux(
small_cuts,
medium_cuts,
large_cuts,
weights=[
122867, # len(small_cuts)
1104071, # len(medium_cuts)
11012085, # len(large_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
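A minimal usage sketch for this data module (assuming the k-means manifests under --manifest-dir have already been prepared; the call sequence just chains the methods defined above):

parser = argparse.ArgumentParser()
LibriLightDataModule.add_arguments(parser)
args = parser.parse_args([])  # use the defaults, e.g. --manifest-dir data/kmeans
librilight = LibriLightDataModule(args)
train_cuts = librilight.train_all_shuf_cuts()
train_dl = librilight.train_dataloaders(train_cuts)
valid_dl = librilight.valid_dataloaders(librilight.dev_clean_cuts())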


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/utils.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/wav2vec2_module.py


@@ -0,0 +1 @@
../../../librispeech/SSL/zipformer/zipformer.py


@@ -0,0 +1,287 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from dataset import HubertAsrDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
"""
DataModule for ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
    - bucketing samplers.
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSets -- they control the "
            "effective batch sizes and sampling strategies.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
            help="When enabled, use the 960h LibriSpeech. "
            "Otherwise, use the 100h subset.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/wav"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=float,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--do-normalize",
type=str2bool,
default=True,
help="whether to normalize the data",
)
def train_dataloaders(
self,
cuts_train: CutSet,
do_normalize: bool,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = HubertAsrDataset(do_normalize=do_normalize)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertAsrDataset(do_normalize=do_normalize)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet, do_normalize: bool) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertAsrDataset(do_normalize=do_normalize)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
return CutSet.mux(
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
weights=[
28539, # len(train_clean_100_cuts)
104014, # len(train_clean_360_cuts)
148688, # len(train_other_500_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
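A sketch of resuming training with this data module (the checkpoint dict and its "sampler" key are illustrative assumptions here; lhotse samplers expose state_dict()/load_state_dict() for exactly this purpose):

librispeech = LibriSpeechAsrDataModule(args)
train_cuts = (
    librispeech.train_all_shuf_cuts()
    if args.full_libri
    else librispeech.train_clean_100_cuts()
)
sampler_state_dict = checkpoint.get("sampler") if checkpoint is not None else None
train_dl = librispeech.train_dataloaders(
    train_cuts,
    do_normalize=args.do_normalize,
    sampler_state_dict=sampler_state_dict,
)
# On save, the current position can be captured with train_dl.sampler.state_dict().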


@@ -0,0 +1,840 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import utils
from torch import Tensor, nn
from torch.nn import Parameter
from utils import FairseqDropout, quant_noise
_xformers_available = False
# TODO: move this into xformers?
# TODO: uint8 input type should just output a bool
def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
"""
call to pytorch multihead accepts three mask types:
- ByteTensor where non-zero means to mask
- FloatTensor which is an additive mask
- BoolTensor where True means to mask
    xFormers currently accepts boolean and additive masks. For boolean masks
    the values have the opposite meaning: for a BoolTensor, True means to keep the value.
"""
float_types = [torch.float, torch.float16]
# If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool.
additive = mask.dtype in float_types
    # If to_dtype is not specified, keep same dtype as mask.
to_dtype = mask.dtype if to_dtype is None else to_dtype
to_additive = to_dtype in float_types
if additive:
if to_additive:
return mask.to(to_dtype)
mask = mask < 0
if to_additive:
# return additive mask
new_mask = torch.zeros_like(mask, dtype=to_dtype)
new_mask = new_mask.masked_fill_(mask, -float("inf"))
return new_mask
# In xFormers True is value to keep rather than value to mask
mask = ~mask.to(torch.bool)
mask = mask.to(to_dtype)
return mask
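# For illustration, the conversions this implements: a BoolTensor where True
# means "mask" becomes an additive float mask, while a boolean mask bound for
# xFormers flips polarity so that True means "keep", e.g.:
#   bool_mask = torch.tensor([[False, True]])
#   _mask_for_xformers(bool_mask, to_dtype=torch.float)  # tensor([[0., -inf]])
#   _mask_for_xformers(bool_mask, to_dtype=torch.bool)   # tensor([[True, False]])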
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention are initialized using
       the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
dictionary=None,
q_noise=0.0,
qn_block_size=8,
# TODO: pass in config rather than string.
# config defined in xformers.components.attention.AttentionConfig
xformers_att_config: Optional[str] = None,
xformers_blocksparse_layout: Optional[
torch.Tensor
] = None, # This should be part of the config
xformers_blocksparse_blocksize: Optional[
int
] = 16, # This should be part of the config
):
super().__init__()
self.use_xformers = False
if self.use_xformers and not _xformers_available:
raise ImportError("\n\n Please install xFormers.")
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert (
not self.self_attention or self.qkv_same_dim
), "Self-attention requires query, key and value to be of the same size"
self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.beam_size = 1
self.reset_parameters()
self.onnx_trace = False
self.skip_embed_dim_check = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def _get_reserve_head_index(self, num_heads_to_keep: int):
k_proj_heads_norm = []
q_proj_heads_norm = []
v_proj_heads_norm = []
for i in range(self.num_heads):
start_idx = i * self.head_dim
end_idx = (i + 1) * self.head_dim
k_proj_heads_norm.append(
torch.sum(
torch.abs(
self.k_proj.weight[
start_idx:end_idx,
]
)
).tolist()
+ torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
)
q_proj_heads_norm.append(
torch.sum(
torch.abs(
self.q_proj.weight[
start_idx:end_idx,
]
)
).tolist()
+ torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
)
v_proj_heads_norm.append(
torch.sum(
torch.abs(
self.v_proj.weight[
start_idx:end_idx,
]
)
).tolist()
+ torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
)
heads_norm = []
for i in range(self.num_heads):
heads_norm.append(
k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
)
sorted_head_index = sorted(
range(self.num_heads), key=lambda k: heads_norm[k], reverse=True
)
reserve_head_index = []
for i in range(num_heads_to_keep):
start = sorted_head_index[i] * self.head_dim
end = (sorted_head_index[i] + 1) * self.head_dim
reserve_head_index.append((start, end))
return reserve_head_index
def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
new_q_weight = []
new_q_bias = []
new_k_weight = []
new_k_bias = []
new_v_weight = []
new_v_bias = []
new_out_proj_weight = []
for ele in reserve_head_index:
start_idx, end_idx = ele
new_q_weight.append(
self.q_proj.weight[
start_idx:end_idx,
]
)
new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
new_k_weight.append(
self.k_proj.weight[
start_idx:end_idx,
]
)
new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
new_v_weight.append(
self.v_proj.weight[
start_idx:end_idx,
]
)
new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
new_q_weight = torch.cat(new_q_weight).detach()
new_k_weight = torch.cat(new_k_weight).detach()
new_v_weight = torch.cat(new_v_weight).detach()
new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
new_q_weight.requires_grad = True
new_k_weight.requires_grad = True
new_v_weight.requires_grad = True
new_out_proj_weight.requires_grad = True
new_q_bias = torch.cat(new_q_bias).detach()
new_q_bias.requires_grad = True
new_k_bias = torch.cat(new_k_bias).detach()
new_k_bias.requires_grad = True
new_v_bias = torch.cat(new_v_bias).detach()
new_v_bias.requires_grad = True
self.q_proj.weight = torch.nn.Parameter(new_q_weight)
self.q_proj.bias = torch.nn.Parameter(new_q_bias)
self.k_proj.weight = torch.nn.Parameter(new_k_weight)
self.k_proj.bias = torch.nn.Parameter(new_k_bias)
self.v_proj.weight = torch.nn.Parameter(new_v_weight)
self.v_proj.bias = torch.nn.Parameter(new_v_bias)
self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight)
self.num_heads = len(reserve_head_index)
self.embed_dim = self.head_dim * self.num_heads
self.q_proj.out_features = self.embed_dim
self.k_proj.out_features = self.embed_dim
self.v_proj.out_features = self.embed_dim
def _set_skip_embed_dim_check(self):
self.skip_embed_dim_check = True
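    # Illustrative pruning flow on a hypothetical instance `mha`: keep the heads
    # with the largest L1 norm, rebuild the q/k/v/out projections, then bypass
    # the embed_dim check since the module now projects to a smaller dimension:
    #   reserve = mha._get_reserve_head_index(num_heads_to_keep=8)
    #   mha._adaptive_prune_heads(reserve)
    #   mha._set_skip_embed_dim_check()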
def _pad_masks(
self,
key_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor],
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
if attn_mask is not None:
shape = attn_mask.size()[:-1] + torch.Size([1])
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
if key_padding_mask is not None:
shape = key_padding_mask.size()[:-1] + torch.Size([1])
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(shape),
],
dim=-1,
)
return key_padding_mask, attn_mask
def _add_bias(
self,
k: Tensor,
v: Tensor,
key_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor],
bsz: int,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
assert self.bias_k is not None
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
key_padding_mask, attn_mask = self._pad_masks(
key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
return k, v, key_padding_mask, attn_mask
def _append_zero_attn(
self,
k: Tensor,
v: Tensor,
key_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=-2,
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=-2,
)
key_padding_mask, attn_mask = self._pad_masks(
key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
return k, v, key_padding_mask, attn_mask
def forward(
self,
query: Tensor,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
if not self.skip_embed_dim_check:
assert (
embed_dim == self.embed_dim
), f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert value is not None
                assert (src_len, key_bsz) == value.shape[:2]
if (
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
            # The Multihead attention implemented in pytorch forces a strong dimension check
            # for the input embedding dimension and the K,Q,V projection dimensions.
            # Since pruning will break the dimension check and it is not easy to modify the pytorch API,
# it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check
and not self.skip_embed_dim_check
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask.bool() if key_padding_mask is not None else None,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
if self.beam_size > 1 and bsz == key.size(1):
# key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
:, :, 0, :
]
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.view(
-1, self.beam_size, key_padding_mask.size(1)
)[:, 0, :]
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
            k, v, key_padding_mask, attn_mask = self._add_bias(
                k, v, key_padding_mask, attn_mask, bsz
            )
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
kv_bsz = bsz # need default value for scripting
if k is not None:
kv_bsz = k.size(1)
k = (
k.contiguous()
.view(-1, kv_bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, kv_bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
kv_bsz = _prev_key.size(0)
prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
assert kv_bsz == _prev_value.size(0)
prev_value = _prev_value.view(
kv_bsz * self.num_heads, -1, self.head_dim
)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=kv_bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(
kv_bsz, self.num_heads, -1, self.head_dim
)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == kv_bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k, v, key_padding_mask, attn_mask = self._append_zero_attn(
k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
if self.encoder_decoder_attention and bsz != kv_bsz:
attn_weights = torch.einsum(
"bxhtd,bhsd->bxhts",
q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
k.view((kv_bsz, self.num_heads) + k.size()[1:]),
)
attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
else:
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [
bsz * self.num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.view(
kv_bsz, -1, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn: Optional[Tensor] = None
if self.encoder_decoder_attention and bsz != kv_bsz:
attn = torch.einsum(
"bxhts,bhsd->bxhtd",
attn_probs.view(
(
kv_bsz,
-1,
self.num_heads,
)
+ attn_probs.size()[1:]
),
v.view(
(
kv_bsz,
self.num_heads,
)
+ v.size()[1:]
),
)
attn = attn.reshape((-1,) + attn.size()[-2:])
else:
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [
bsz * self.num_heads,
tgt_len,
self.head_dim,
]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
@torch.jit.export
def reorder_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
if self.encoder_decoder_attention:
if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
return incremental_state
elif self.beam_size > 1:
input_buffer[k] = input_buffer_k.index_select(
0,
new_order.reshape(-1, self.beam_size)[:, 0]
// self.beam_size,
)
else:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
else:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return incremental_state
    def set_beam_size(self, beam_size):
        """Used for efficient beamable enc-dec attention"""
        self.beam_size = beam_size
def _get_input_buffer(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
items_to_add = {}
keys_to_remove = []
for k in state_dict.keys():
if k.endswith(prefix + "in_proj_weight"):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
keys_to_remove.append(k)
k_bias = prefix + "in_proj_bias"
if k_bias in state_dict.keys():
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
keys_to_remove.append(prefix + "in_proj_bias")
for k in keys_to_remove:
del state_dict[k]
for key, value in items_to_add.items():
state_dict[key] = value
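A minimal smoke-test sketch for the self-attention fast path (the Time x Batch x Channel convention follows the forward docstring; the sizes here are arbitrary):

mha = MultiheadAttention(embed_dim=768, num_heads=12, self_attention=True)
x = torch.randn(100, 2, 768)  # (tgt_len, bsz, embed_dim)
attn_out, attn_weights = mha(x, x, x)
assert attn_out.shape == (100, 2, 768)
assert attn_weights.shape == (2, 100, 100)  # averaged over heads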


@@ -0,0 +1 @@
../../ASR/zipformer/beam_search.py


@@ -0,0 +1,367 @@
# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang)
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import read_audio_from_cuts
from torch.utils.data.dataloader import default_collate
class HubertDataset(torch.utils.data.Dataset):
"""
In this implementation, there will always be a single channel.
Returns:
.. code-block::
{
'audio': (B x NumSamples) float tensor
}
"""
def __init__(
self,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.label_rate = label_rate
self.random_crop = random_crop
self.pad_audio = pad_audio
self.num_classes = num_classes
self.normalize = do_normalize
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
self._validate(cuts)
audio, _ = read_audio_from_cuts(cuts)
for i, item in enumerate(audio):
audio[i] = self.postprocess(item, self.sample_rate)
audio_lens = [cut.num_samples for cut in cuts]
if self.pad_audio:
audio_size = min(max(audio_lens), self.max_sample_size)
else:
audio_size = min(min(audio_lens), self.max_sample_size)
audio, padding_mask, audio_starts = self.collater_audio(
audio, audio_lens, audio_size
)
kmeans = [cut.custom["kmeans"] for cut in cuts]
kmeans = [
torch.tensor([int(item) for item in label.split()], dtype=torch.int64)
for label in kmeans
]
kmeans, _ = self.collater_frm_label(kmeans, audio_size, audio_starts)
return {
"cuts": cuts,
"audio": audio,
"padding_mask": padding_mask,
"kmeans": kmeans,
}
def postprocess(self, wav, cur_sample_rate):
if wav.dim() == 2:
wav = wav.mean(-1)
assert wav.dim() == 1, wav.dim()
if cur_sample_rate != self.sample_rate:
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
if self.normalize:
with torch.no_grad():
wav = F.layer_norm(wav, wav.shape)
return wav
def _validate(self, cuts: CutSet) -> None:
validate(cuts)
assert all(cut.has_recording for cut in cuts)
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end], start
def collater_audio(self, audios, audio_lens, audio_size):
collated_audios = audios[0].new_zeros(len(audios), audio_size)
padding_mask = (
torch.BoolTensor(collated_audios.shape).fill_(False)
# if self.pad_audio else None
)
audio_starts = [0 for _ in audios]
for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
audio = audio[:audio_len]
diff = audio_len - audio_size
if diff == 0:
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
audio, audio_size
)
return collated_audios, padding_mask, audio_starts
def collate_tokens(
self,
values,
pad_idx,
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
pad_to_length=None,
pad_to_multiple=1,
pad_to_bsz=None,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
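        # For example: size=37, pad_to_multiple=8 -> int((36.9 // 8 + 1) * 8) = 40;
        # sizes that are already exact multiples skip this branch unchanged.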
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
res = values[0].new(batch_size, size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
if eos_idx is None:
# if no eos_idx is specified, then use the last token in src
dst[0] = src[-1]
else:
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
def collater_frm_label(self, targets, audio_size, audio_starts):
label_rate = self.label_rate
pad = self.num_classes[0] - 1
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts]
frm_size = int(round(audio_size * s2f))
if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
lengths = torch.LongTensor([len(t) for t in targets])
targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths
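    # Worked example for the alignment above: with sample_rate=16000 and
    # label_rate=50, s2f = 50 / 16000 = 0.003125, so a crop starting at audio
    # sample 32000 begins at label frame int(round(32000 * 0.003125)) = 100,
    # and a 2-second (32000-sample) crop covers 100 label frames.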
class HubertAsrDataset(torch.utils.data.Dataset):
"""
In this implementation, there will always be a single channel.
Returns:
.. code-block::
{
'audio': (B x NumSamples) float tensor
}
"""
def __init__(
self,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
random_crop: bool = True,
pad_audio: bool = True,
do_normalize: bool = True,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.random_crop = random_crop
self.pad_audio = pad_audio
self.normalize = do_normalize
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
self._validate(cuts)
audio, _ = read_audio_from_cuts(cuts)
for i, item in enumerate(audio):
audio[i] = self.postprocess(item, self.sample_rate)
audio_lens = [cut.num_samples for cut in cuts]
if self.pad_audio:
audio_size = min(max(audio_lens), self.max_sample_size)
else:
audio_size = min(min(audio_lens), self.max_sample_size)
audio, padding_mask, audio_starts = self.collater_audio(
audio, audio_lens, audio_size
)
return {
"cuts": cuts,
"audio": audio,
"padding_mask": padding_mask,
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
def postprocess(self, wav, cur_sample_rate):
if wav.dim() == 2:
wav = wav.mean(-1)
assert wav.dim() == 1, wav.dim()
if cur_sample_rate != self.sample_rate:
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
if self.normalize:
with torch.no_grad():
wav = F.layer_norm(wav, wav.shape)
return wav
def _validate(self, cuts: CutSet) -> None:
validate(cuts)
assert all(cut.has_recording for cut in cuts)
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end], start
def collater_audio(self, audios, audio_lens, audio_size):
collated_audios = audios[0].new_zeros(len(audios), audio_size)
padding_mask = (
torch.BoolTensor(collated_audios.shape).fill_(False)
# if self.pad_audio else None
)
audio_starts = [0 for _ in audios]
for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
audio = audio[:audio_len]
diff = audio_len - audio_size
if diff == 0:
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
audio, audio_size
)
return collated_audios, padding_mask, audio_starts
def collate_tokens(
self,
values,
pad_idx,
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
pad_to_length=None,
pad_to_multiple=1,
pad_to_bsz=None,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
res = values[0].new(batch_size, size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
if eos_idx is None:
# if no eos_idx is specified, then use the last token in src
dst[0] = src[-1]
else:
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
if __name__ == "__main__":
from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler
from torch.utils.data import DataLoader
dataset = HubertDataset()
cuts = load_manifest_lazy("data/fbank2/librispeech_cuts_train-clean-100.jsonl.gz")
sampler = DynamicBucketingSampler(
cuts,
max_duration=100,
shuffle=False,
)
dl = DataLoader(
dataset,
batch_size=None,
sampler=sampler,
num_workers=2,
)
for batch_idx, batch in enumerate(dl):
print(batch)
break

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -0,0 +1 @@
../../ASR/zipformer/decoder.py

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -0,0 +1,984 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
add_masks: bool = False,
seed: Optional[int] = None,
epoch: Optional[int] = None,
indices: Optional[torch.Tensor] = None,
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
"""
Computes random mask spans for a given shape
    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. This will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            However, due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
if num_mask_ver == 1:
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if seed is not None and epoch is not None and indices is not None:
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
else:
seed_i = None
rng = np.random.default_rng(seed_i)
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
assert sz >= 0, sz
else:
sz = all_sz
if num_mask_ver == 1:
if padding_mask is not None:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
num_mask = all_num_mask
elif num_mask_ver == 2:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ rng.random()
)
num_mask = max(min_masks, num_mask)
else:
raise ValueError()
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = rng.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = rng.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
if mask_type == "static":
raise ValueError(f"this should never happens")
else:
lengths = [min(mask_length, sz - 1)]
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = rng.integers(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = rng.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
if idc_select_ver == 1:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
elif idc_select_ver == 2:
mask_idc = rng.choice(sz, num_mask, replace=False)
else:
raise ValueError()
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idc = np.unique(mask_idc[mask_idc < sz])
if len(mask_idc) >= sz:
raise ValueError(
(
f"the entire sequence is masked. "
f"sz={sz}; mask_idc[mask_idc]; "
f"index={indices[i] if indices is not None else None}"
)
)
mask_idcs.append(mask_idc)
target_len = None
if require_same_masks:
if add_masks:
target_len = max([len(m) for m in mask_idcs])
else:
target_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if target_len is not None and len(mask_idc) > target_len:
mask_idc = rng.choice(mask_idc, target_len, replace=False)
mask[i, mask_idc] = True
if target_len is not None and len(mask_idc) < target_len:
unmasked = np.flatnonzero(~mask[i])
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
mask[i, to_mask] = True
if mask_dropout > 0:
masked = np.flatnonzero(mask[i])
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
to_drop = rng.choice(masked, num_holes, replace=False)
mask[i, to_drop] = False
return mask
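# Illustrative self-check (an added sketch, not part of the original
# fairseq-derived code): exercises compute_mask_indices on a toy batch and
# verifies the invariants promised in the docstring above.
def _demo_compute_mask_indices() -> None:
    # Two sequences of 100 timesteps, no padding, HuBERT-style settings.
    mask = compute_mask_indices(
        shape=(2, 100),
        padding_mask=None,
        mask_prob=0.65,
        mask_length=10,
        min_masks=2,
    )
    assert mask.shape == (2, 100) and mask.dtype == bool
    # With require_same_masks=True (the default), every sample in the batch
    # ends up with the same number of masked timesteps.
    assert len(set(mask.sum(axis=1).tolist())) == 1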
def add_hubert_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--label-rate",
type=float,
default=50,
)
parser.add_argument(
"--sample-rate",
type=float,
default=16000,
)
parser.add_argument(
"--extractor-mode",
type=str,
default="default",
help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. default has a single group
norm with d groups in the first conv block, whereas layer_norm
has layer norms in every block (meant to use with normalize=True)""",
)
parser.add_argument(
"--encoder-layers",
type=int,
default=12,
help="num encoder layers in the transformer",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
default=768,
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
default=3072,
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
default=12,
help="num encoder attention heads",
)
parser.add_argument(
"--activation-fn",
type=str,
choices=[
"relu",
"gelu",
"gelu_fast",
"gelu_accurate",
"tanh",
"linear",
],
default="gelu",
help="activation function to use",
)
parser.add_argument(
"--layer-type",
type=str,
choices=["transformer", "conformer", "trf_adp"],
default="transformer",
help="layer type in encoder",
)
# dropouts
parser.add_argument(
"--dropout",
type=float,
default=0.1,
help="dropout probability for the transformer",
)
parser.add_argument(
"--attention-dropout",
type=float,
default=0.1,
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
type=float,
default=0.0,
help="dropout probability after activation in FFN",
)
parser.add_argument(
"--encoder-layerdrop",
type=float,
default=0.0,
help="probability of dropping a tarnsformer layer",
)
parser.add_argument(
"--dropout-input",
type=float,
default=0.0,
help="dropout to apply to the input (after feat extr)",
)
parser.add_argument(
"--dropout-features",
type=float,
default=0.0,
help="dropout to apply to the features (after feat extr)",
)
parser.add_argument(
"--final-dim",
type=int,
default=0,
help="project final representations and targets to this many dimensions. set to encoder_embed_dim is <= 0",
)
parser.add_argument(
"--untie-final-proj",
type=str2bool,
default=False,
help="use separate projection for each target",
)
parser.add_argument(
"--layer-norm-first",
type=str2bool,
default=False,
help="apply layernorm first in the transformer",
)
parser.add_argument(
"--conv-feature-layers",
type=str,
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
)
parser.add_argument(
"--conv-bias",
type=str2bool,
default=False,
help="include bias in conv encoder",
)
parser.add_argument(
"--logit-temp",
type=float,
default=0.1,
help="temperature to divide logits by",
)
parser.add_argument(
"--target-glu",
type=str2bool,
default=False,
help="adds projection + glu to targets",
)
parser.add_argument(
"--feature-grad-mult",
type=float,
default=1.0,
help="multiply feature extractor var grads by this",
)
# masking
parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
parser.add_argument(
"--mask-prob",
type=float,
default=0.65,
help="probability of replacing a token with mask",
)
parser.add_argument(
"--mask-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
default="static",
help="how to choose mask length",
)
parser.add_argument(
"--mask-other",
type=float,
default=0,
help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh",
)
parser.add_argument(
"--no-mask-overlap",
type=str2bool,
default=False,
help="whether to prevent mask spans from overlapping",
)
parser.add_argument(
"--mask-min-space",
type=int,
default=1,
help="min space between spans (if no overlap is enabled)",
)
# channel masking
parser.add_argument(
"--mask-channel-length",
type=int,
default=10,
help="length of the mask for features (channels)",
)
parser.add_argument(
"--mask-channel-prob",
type=float,
default=0.0,
help="probability of replacing a feature with 0",
)
parser.add_argument(
"--mask-channel-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
default="static",
help="how to choose mask length for channel masking",
)
parser.add_argument(
"--mask-channel-other",
type=float,
default=0,
help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh",
)
parser.add_argument(
"--no-mask-channel-overlap",
type=str2bool,
default=False,
help="whether to prevent channel mask spans from overlapping",
)
parser.add_argument(
"--mask-channel-min-space",
type=int,
default=1,
help="min space between spans (if no overlap is enabled)",
)
# positional embeddings
parser.add_argument(
"--conv-pos",
type=int,
default=128,
help="number of filters for convolutional positional embeddings",
)
parser.add_argument(
"--conv-pos-groups",
type=int,
default=16,
help="number of groups for convolutional positional embedding",
)
parser.add_argument(
"--conv-pos-batch-norm",
type=str2bool,
default=False,
help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
)
parser.add_argument(
"--latent-temp",
type=float,
nargs="*",
default=[2, 0.5, 0.999995],
help="legacy (to be removed)",
)
# loss computation
parser.add_argument(
"--skip-masked",
type=str2bool,
default=False,
help="skip computing losses over masked frames",
)
parser.add_argument(
"--skip-nomask",
type=str2bool,
default=False,
help="skip computing losses over unmasked frames",
)
parser.add_argument(
"--checkpoint-activations",
type=str2bool,
default=False,
help="recompute activations and save memory for extra compute",
)
parser.add_argument(
"--pred-masked-weight",
type=float,
default=1,
help="weight for masked part in ssl loss",
)
parser.add_argument(
"--pred-nomask-weight",
type=float,
default=0,
help="weight for masked part in ssl loss",
)
parser.add_argument(
"--loss-weights",
type=float,
nargs="*",
default=[10],
help="weight for masked part in ssl loss",
)
# FP16 optimization
parser.add_argument(
"--required-seq-len-multiple",
type=int,
default=2,
help="pad the input to encoder such that the sequence length is divisible by multiple",
)
parser.add_argument(
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
)
parser.add_argument(
"--pos-enc-type",
type=str,
default="abs",
help="Positional encoding type to use in conformer",
)
parser.add_argument(
"--num-classes",
type=int,
nargs="*",
default=[504],
help="""num class, a little larger than the number of cluster,
the largest is for padding,
and the value should be the multiple of 4, for faster computation""",
)
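# Usage sketch (an added illustration, not in the original file):
# add_hubert_arguments is meant to be applied to the training script's
# parser; parsing an empty argv yields the HuBERT-Base style defaults above.
def _demo_hubert_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    add_hubert_arguments(parser)
    cfg = parser.parse_args([])
    assert cfg.encoder_layers == 12 and cfg.encoder_embed_dim == 768
    return cfg  # HubertModel(cfg) consumes this namespace directly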
class HubertModel(nn.Module):
def __init__(
self,
cfg,
) -> None:
super().__init__()
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
self.target_glu = None
if cfg.target_glu:
self.target_glu = nn.Sequential(
nn.Linear(final_dim, final_dim * 2), nn.GLU()
)
self.untie_final_proj = cfg.untie_final_proj
if self.untie_final_proj:
self.final_proj = nn.Linear(
cfg.encoder_embed_dim, final_dim * len(cfg.num_classes)
)
else:
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
# modules below are not needed during fine-tuning
self.num_classes = cfg.num_classes
self.label_embs_concat = nn.Parameter(
torch.FloatTensor(sum(self.num_classes), final_dim)
)
self.pred_masked_weight = cfg.pred_masked_weight
self.pred_nomask_weight = cfg.pred_nomask_weight
self.loss_weights = cfg.loss_weights
nn.init.uniform_(self.label_embs_concat)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
# nn.Module has no such hook, so unlike the fairseq original there is
# no parent implementation to chain to here.
return state_dict
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb.to(x.dtype)
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def compute_nce(self, x, pos, negs):
neg_is_pos = (pos == negs).all(-1)
pos = pos.unsqueeze(0)
targets = torch.cat([pos, negs], dim=0)
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
logits /= self.logit_temp
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
return logits
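# Worked shape example for compute_nce (added note): with S masked frames,
# x is (S, D); pos holds the ground-truth embedding per frame, (S, D); negs
# holds Neg distractors per frame, (Neg, S, D). Stacking pos on top of negs
# gives targets of shape (Neg + 1, S, D); cosine similarity over D yields
# (Neg + 1, S), and the final transpose returns (S, Neg + 1) logits whose
# correct class is always index 0, matching the zeros from get_targets().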
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
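# Worked example for feat2tar_ratio (added note): the default conv stack
# [(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2 has strides
# 5 * 2 * 2 * 2 * 2 * 2 * 2 = 320, so feature_ds_rate = 320 and, with
# label_rate = 50 and sample_rate = 16000,
# feat2tar_ratio = 50 * 320 / 16000 = 1.0, i.e. one label per feature frame.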
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
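# Example for forward_padding_mask (added note): with 3200-sample inputs
# downsampled 320x to 10 feature frames, the (B, 3200) sample-level padding
# mask is viewed as (B, 10, 320) and reduced with all(-1), so a frame counts
# as padding only if every audio sample inside it is padded.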
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
):
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
def compute_pred(proj_x, target, label_embs):
# compute logits for the i-th label set
y = torch.index_select(label_embs, 0, target.long())
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
if self.target_glu:
y = self.target_glu(y)
negs = self.target_glu(negs)
# proj_x: (S, D)
# y: (S, D)
# negs: (Neg, S, D)
return self.compute_nce(proj_x, y, negs)
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
if self.untie_final_proj:
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
else:
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
logit_m_list = [
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
if self.untie_final_proj:
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
else:
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
logit_u_list = [
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
]
else:
logit_u_list = [None for _ in target_list]
# result = {
# "logit_m_list": logit_m_list,
# "logit_u_list": logit_u_list,
# "padding_mask": padding_mask,
# "features_pen": features_pen,
# }
return self.compute_loss(logit_m_list, logit_u_list, features_pen)
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x.float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self):
self.target_glu = None
self.final_proj = None
def compute_loss(self, logit_m_list, logit_u_list, features_pen):
loss = 0.0
sample_size = 0
logging_output = {}
reduce = True
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = [x.float() for x in logit_m_list if x is not None]
targ_m_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_m_list]
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += targ_m_list[0].numel()
loss_u_list = []
logp_u_list = [x.float() for x in logit_u_list if x is not None]
targ_u_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_u_list]
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += targ_u_list[0].numel()
if self.loss_weights is not None:
extra_losses = []
names = []
extra_losses.append(features_pen)
names.append("features_pen")
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
**logging_output,
}
# for lk in self.log_keys:
# if lk in net_output:
# logging_output[lk] = float((net_output[lk]))
def compute_correct(logits):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, logp_m in enumerate(logp_m_list):
corr_m, count_m = compute_correct(logp_m)
logging_output[f"correct_m_{i}"] = corr_m
logging_output[f"count_m_{i}"] = count_m
for i, logp_u in enumerate(logp_u_list):
corr_u, count_u = compute_correct(logp_u)
logging_output[f"correct_u_{i}"] = corr_u
logging_output[f"count_u_{i}"] = count_u
return loss, sample_size, logging_output

View File

@ -0,0 +1,940 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
from icefall.utils import str2bool
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
add_masks: bool = False,
seed: Optional[int] = None,
epoch: Optional[int] = None,
indices: Optional[torch.Tensor] = None,
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
mask_dropout: randomly dropout this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
if num_mask_ver == 1:
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if seed is not None and epoch is not None and indices is not None:
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
else:
seed_i = None
rng = np.random.default_rng(seed_i)
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
assert sz >= 0, sz
else:
sz = all_sz
if num_mask_ver == 1:
if padding_mask is not None:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
num_mask = all_num_mask
elif num_mask_ver == 2:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ rng.random()
)
num_mask = max(min_masks, num_mask)
else:
raise ValueError()
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = rng.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = rng.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
if mask_type == "static":
raise ValueError(f"this should never happens")
else:
lengths = [min(mask_length, sz - 1)]
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = rng.integers(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = rng.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
if idc_select_ver == 1:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
elif idc_select_ver == 2:
mask_idc = rng.choice(sz, num_mask, replace=False)
else:
raise ValueError()
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idc = np.unique(mask_idc[mask_idc < sz])
if len(mask_idc) >= sz:
raise ValueError(
(
f"the entire sequence is masked. "
f"sz={sz}; mask_idc[mask_idc]; "
f"index={indices[i] if indices is not None else None}"
)
)
mask_idcs.append(mask_idc)
target_len = None
if require_same_masks:
if add_masks:
target_len = max([len(m) for m in mask_idcs])
else:
target_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if target_len is not None and len(mask_idc) > target_len:
mask_idc = rng.choice(mask_idc, target_len, replace=False)
mask[i, mask_idc] = True
if target_len is not None and len(mask_idc) < target_len:
unmasked = np.flatnonzero(~mask[i])
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
mask[i, to_mask] = True
if mask_dropout > 0:
masked = np.flatnonzero(mask[i])
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
to_drop = rng.choice(masked, num_holes, replace=False)
mask[i, to_drop] = False
return mask
def add_hubert_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--label-rate",
type=float,
default=50,
)
parser.add_argument(
"--sample-rate",
type=float,
default=16000,
)
parser.add_argument(
"--extractor-mode",
type=str,
default="default",
help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. default has a single group
norm with d groups in the first conv block, whereas layer_norm
has layer norms in every block (meant to use with normalize=True)""",
)
parser.add_argument(
"--encoder-layers",
type=int,
default=12,
help="num encoder layers in the transformer",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
default=768,
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
default=3072,
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
default=12,
help="num encoder attention heads",
)
parser.add_argument(
"--activation-fn",
type=str,
choices=[
"relu",
"gelu",
"gelu_fast",
"gelu_accurate",
"tanh",
"linear",
],
default="gelu",
help="activation function to use",
)
parser.add_argument(
"--layer-type",
type=str,
choices=["transformer", "conformer", "trf_adp"],
default="transformer",
help="layer type in encoder",
)
# dropouts
parser.add_argument(
"--dropout",
type=float,
default=0.1,
help="dropout probability for the transformer",
)
parser.add_argument(
"--attention-dropout",
type=float,
default=0.1,
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
type=float,
default=0.0,
help="dropout probability after activation in FFN",
)
parser.add_argument(
"--encoder-layerdrop",
type=float,
default=0.0,
help="probability of dropping a tarnsformer layer",
)
parser.add_argument(
"--dropout-input",
type=float,
default=0.0,
help="dropout to apply to the input (after feat extr)",
)
parser.add_argument(
"--dropout-features",
type=float,
default=0.0,
help="dropout to apply to the features (after feat extr)",
)
parser.add_argument(
"--final-dim",
type=int,
default=0,
help="project final representations and targets to this many dimensions. set to encoder_embed_dim is <= 0",
)
parser.add_argument(
"--untie-final-proj",
type=str2bool,
default=False,
help="use separate projection for each target",
)
parser.add_argument(
"--layer-norm-first",
type=str2bool,
default=False,
help="apply layernorm first in the transformer",
)
parser.add_argument(
"--conv-feature-layers",
type=str,
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
)
parser.add_argument(
"--conv-bias",
type=str2bool,
default=False,
help="include bias in conv encoder",
)
parser.add_argument(
"--logit-temp",
type=float,
default=0.1,
help="temperature to divide logits by",
)
parser.add_argument(
"--target-glu",
type=str2bool,
default=False,
help="adds projection + glu to targets",
)
parser.add_argument(
"--feature-grad-mult",
type=float,
default=1.0,
help="multiply feature extractor var grads by this",
)
# masking
parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
parser.add_argument(
"--mask-prob",
type=float,
default=0.65,
help="probability of replacing a token with mask",
)
parser.add_argument(
"--mask-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
default="static",
help="how to choose mask length",
)
parser.add_argument(
"--mask-other",
type=float,
default=0,
help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh",
)
parser.add_argument(
"--no-mask-overlap",
type=str2bool,
default=False,
help="whether to prevent mask spans from overlapping",
)
parser.add_argument(
"--mask-min-space",
type=int,
default=1,
help="min space between spans (if no overlap is enabled)",
)
# channel masking
parser.add_argument(
"--mask-channel-length",
type=int,
default=10,
help="length of the mask for features (channels)",
)
parser.add_argument(
"--mask-channel-prob",
type=float,
default=0.0,
help="probability of replacing a feature with 0",
)
parser.add_argument(
"--mask-channel-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
default="static",
help="how to choose mask length for channel masking",
)
parser.add_argument(
"--mask-channel-other",
type=float,
default=0,
help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh",
)
parser.add_argument(
"--no-mask-channel-overlap",
type=str2bool,
default=False,
help="whether to prevent channel mask spans from overlapping",
)
parser.add_argument(
"--mask-channel-min-space",
type=int,
default=1,
help="min space between spans (if no overlap is enabled)",
)
# positional embeddings
parser.add_argument(
"--conv-pos",
type=int,
default=128,
help="number of filters for convolutional positional embeddings",
)
parser.add_argument(
"--conv-pos-groups",
type=int,
default=16,
help="number of groups for convolutional positional embedding",
)
parser.add_argument(
"--conv-pos-batch-norm",
type=str2bool,
default=False,
help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
)
parser.add_argument(
"--latent-temp",
type=float,
nargs="*",
default=[2, 0.5, 0.999995],
help="legacy (to be removed)",
)
# loss computation
parser.add_argument(
"--skip-masked",
type=str2bool,
default=False,
help="skip computing losses over masked frames",
)
parser.add_argument(
"--skip-nomask",
type=str2bool,
default=False,
help="skip computing losses over unmasked frames",
)
parser.add_argument(
"--checkpoint-activations",
type=str2bool,
default=False,
help="recompute activations and save memory for extra compute",
)
parser.add_argument(
"--pred-masked-weight",
type=float,
default=1,
help="weight for masked part in ssl loss",
)
parser.add_argument(
"--pred-nomask-weight",
type=float,
default=0,
help="weight for masked part in ssl loss",
)
parser.add_argument(
"--loss-weights",
type=float,
nargs="*",
default=[10],
help="weight for masked part in ssl loss",
)
# FP16 optimization
parser.add_argument(
"--required-seq-len-multiple",
type=int,
default=2,
help="pad the input to encoder such that the sequence length is divisible by multiple",
)
parser.add_argument(
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
)
parser.add_argument(
"--pos-enc-type",
type=str,
default="abs",
help="Positional encoding type to use in conformer",
)
parser.add_argument(
"--num-classes",
type=int,
nargs="*",
default=[504],
help="""num class, a little larger than the number of cluster,
the largest is for padding,
and the value should be the multiple of 4, for faster computation""",
)
class HubertModel(nn.Module):
def __init__(
self,
cfg,
) -> None:
super().__init__()
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
if self.embed != cfg.encoder_embed_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
self.untie_final_proj = cfg.untie_final_proj
self.final_proj = nn.Linear(cfg.encoder_embed_dim, sum(cfg.num_classes))
# modules below are not needed during fine-tuning
self.num_classes = cfg.num_classes
self.pred_masked_weight = cfg.pred_masked_weight
self.pred_nomask_weight = cfg.pred_nomask_weight
self.loss_weights = cfg.loss_weights
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
# nn.Module has no such hook, so unlike the fairseq original there is
# no parent implementation to chain to here.
return state_dict
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb.to(x.dtype)
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
):
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
proj_x_m /= self.logit_temp
logit_m_list = [proj_x_m for _ in range(len(target_list))]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
proj_x_u /= self.logit_temp
logit_u_list = [proj_x_u for _ in range(len(target_list))]
else:
logit_u_list = [None for _ in target_list]
# result = {
# "logit_m_list": logit_m_list,
# "logit_u_list": logit_u_list,
# "padding_mask": padding_mask,
# "features_pen": features_pen,
# }
targ_m_list = target_list[0][masked_indices]
targ_m_list = targ_m_list.long()
targ_m_list = [targ_m_list for _ in range(len(target_list))]
targ_u_list = target_list[0][nomask_indices]
targ_u_list = targ_u_list.long()
targ_u_list = [targ_u_list for _ in range(len(target_list))]
return self.compute_loss(
logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
)
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x.float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self):
self.final_proj = None
def compute_loss(
self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
):
loss = 0.0
sample_size = 0
logging_output = {}
reduce = True
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = [x.float() for x in logit_m_list if x is not None]
logp_m_list = torch.cat(logp_m_list)
targ_m_list = torch.cat(targ_m_list)
loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_0"] = loss_m.detach().item()
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += len(targ_m_list)
loss_u_list = []
logp_u_list = [x.float() for x in logit_u_list if x is not None]
logp_u_list = torch.cat(logp_u_list)
targ_u_list = torch.cat(targ_u_list)
loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_0"] = loss_u.detach().item()
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += len(targ_u_list)
if self.loss_weights is not None:
extra_losses = []
names = []
extra_losses.append(features_pen)
names.append("features_pen")
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
**logging_output,
}
# for lk in self.log_keys:
# if lk in net_output:
# logging_output[lk] = float((net_output[lk]))
def compute_correct(logits, target):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == target
min = logits.argmin(-1) == target
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
logging_output[f"correct_m_0"] = corr_m
logging_output[f"count_m_0"] = count_m
corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
logging_output[f"correct_u_0"] = corr_u
logging_output[f"count_u_0"] = count_u
return loss, sample_size, logging_output
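# Added note: unlike the NCE-style model above, this cross-entropy variant
# skips label embeddings and cosine-similarity negatives entirely;
# final_proj maps encoder frames straight to sum(num_classes) logits, which
# are divided by logit_temp and scored with plain softmax cross-entropy
# against the k-means cluster ids in target_list.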

View File

@ -0,0 +1 @@
../../ASR/zipformer/joiner.py

View File

@ -0,0 +1,344 @@
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
class AsrModel(nn.Module):
def __init__(
self,
encoder,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
encoder_dim: int = 768,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder:
It is the transcription network in the paper. It accepts
inputs: `x` of (N, T, encoder_dim).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 2-D tensor of shape (N, T).
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
if padding_mask is None:
padding_mask = torch.zeros_like(x, dtype=torch.bool)
encoder_out, padding_mask = self.encoder.extract_features(
source=x,
padding_mask=padding_mask,
mask=self.encoder.training,
)
encoder_out_lens = torch.sum(~padding_mask, dim=1)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
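# Added note: each boundary row is
# [begin_symbol, begin_frame, end_symbol, end_frame] = [0, 0, y_len, T],
# the layout expected by k2.rnnt_loss_smoothed and k2.rnnt_loss_pruned.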
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
y: k2.RaggedTensor,
padding_mask: Optional[torch.Tensor] = None,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 2-D tensor of shape (N, T).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses, the CTC loss and the encoder output lengths,
in the form (simple_loss, pruned_loss, ctc_loss, encoder_out_lens)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 2, x.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == y.dim0, (x.shape, y.dim0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
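# Combination sketch (an added illustration; the weighting actually lives in
# the training script, and the scale names below are assumptions):
#
#     simple_loss, pruned_loss, ctc_loss, _ = model(x, y)
#     loss = simple_loss_scale * simple_loss + pruned_loss + ctc_loss_scale * ctc_loss
#
# with the pruned transducer loss serving as the main objective once the
# simple-loss warm-up phase is over.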

View File

@ -0,0 +1 @@
../../ASR/zipformer/optim.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../ASR/zipformer/scaling.py

View File

@ -0,0 +1,341 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from dataset import HubertDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechDataModule:
"""
DataModule for SSL experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in SSL
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
This class should be derived for specific corpora used in SSL tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="SSL data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/kmeans"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=float,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--do-normalize",
type=str2bool,
default=True,
help="whether to normalize the data",
)
group.add_argument(
"--random-crop",
type=str2bool,
default=True,
help="always crop from the beginning if false",
)
def train_dataloaders(
self,
cuts_train: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(
self,
cuts_valid: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(
self,
cuts: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, "
"train-clean-360 and train-other-500 cuts"
)
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
return CutSet.mux(
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
weights=[
28539, # len(train_clean_100_cuts)
104014, # len(train_clean_360_cuts)
148688, # len(train_other_500_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
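# Usage sketch (the option names follow the arguments defined above; the
# calling training script is assumed to parse them into `args`):
#
#   parser = argparse.ArgumentParser()
#   LibriSpeechDataModule.add_arguments(parser)
#   args = parser.parse_args()
#   datamodule = LibriSpeechDataModule(args)
#   train_dl = datamodule.train_dataloaders(datamodule.train_all_shuf_cuts())
#   valid_dl = datamodule.valid_dataloaders(datamodule.dev_clean_cuts())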

View File

@ -0,0 +1,338 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import Callable, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
def relu_squared(x: torch.Tensor):
return F.relu(x).pow(2)
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def is_xla_tensor(tensor):
return torch.is_tensor(tensor) and tensor.device.type == "xla"
def index_put(tensor, indices, value):
if is_xla_tensor(tensor):
for _ in range(indices.dim(), tensor.dim()):
indices = indices.unsqueeze(-1)
if indices.size(-1) < tensor.size(-1):
indices = indices.expand_as(tensor)
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
else:
tensor[indices] = value
return tensor
def pad_to_multiple(x, multiple, dim=-1, value=0):
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
if x is None:
return None, 0
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
if m.is_integer():
return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
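# Worked example (sketch): padding the time axis (dim=-2) of a
# (2, 10, 50) tensor to a multiple of 8 appends 6 zero frames:
#
#   x = torch.zeros(2, 10, 50)
#   y, pad = pad_to_multiple(x, 8, dim=-2)  # y.shape == (2, 16, 50), pad == 6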
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str) -> Callable:
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "relu_squared":
return relu_squared
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "swish":
return F.silu  # fixed: the upstream code returns the nn.SiLU class itself, which is not a function
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class SamePad2d(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
assert len(x.size()) == 4
if self.remove > 0:
x = x[:, :, : -self.remove, : -self.remove]
return x
class TransposeLast(nn.Module):
def __init__(self, deconstruct_idx=None, transpose_dim=-2):
super().__init__()
self.deconstruct_idx = deconstruct_idx
self.transpose_dim = transpose_dim
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(self.transpose_dim, -1)
try:
from apex.normalization import FusedLayerNorm as _FusedLayerNorm
has_fused_layernorm = True
class FusedLayerNorm(_FusedLayerNorm):
@torch.jit.unused
def forward(self, x):
if not x.is_cuda:
return super().forward(x)
else:
with torch.cuda.device(x.device):
return super().forward(x)
except ImportError:
has_fused_layernorm = False
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
if torch.jit.is_scripting() or torch.jit.is_tracing():
export = True
if not export and torch.cuda.is_available() and has_fused_layernorm:
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class Fp32LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class Fp32GroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.group_norm(
input.float(),
self.num_groups,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
def softmax(x, dim: int, onnx_trace: bool = False):
if onnx_trace:
return F.softmax(x.float(), dim=dim)
else:
return F.softmax(x, dim=dim, dtype=torch.float32)
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists of randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features,
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
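# Usage sketch (illustrative sizes): during training, random 8-column
# blocks of the weight matrix are zeroed with probability 0.1 and the
# surviving weights are rescaled by 1 / (1 - p); evaluation is unaffected.
#
#   layer = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)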
class FairseqDropout(nn.Module):
def __init__(self, p, module_name=None):
super().__init__()
self.p = p
self.module_name = module_name
self.apply_during_inference = False
def forward(self, x, inplace: bool = False):
if self.p > 0 and (self.training or self.apply_during_inference):
return F.dropout(x, p=self.p, training=True, inplace=inplace)
else:
return x
def make_generation_fast_(
self,
name: str,
retain_dropout: bool = False,
retain_dropout_modules: Optional[List[str]] = None,
**kwargs
):
if retain_dropout:
if retain_dropout_modules is not None and self.module_name is None:
pass
elif (
retain_dropout_modules is None # if None, apply to all modules
or self.module_name in retain_dropout_modules
):
self.apply_during_inference = True
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
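# Usage sketch: HuBERT-style models use this to down-weight gradients
# flowing into the convolutional feature extractor without changing the
# forward activations, e.g.
#
#   features = GradMultiply.apply(features, 0.1)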

View File

@ -0,0 +1,593 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from attention_module import MultiheadAttention, init_bert_params
from utils import (
Fp32GroupNorm,
Fp32LayerNorm,
LayerNorm,
SamePad,
TransposeLast,
get_activation_fn,
index_put,
pad_to_multiple,
)
class ConvFeatureExtractionModel(nn.Module):
def __init__(
self,
conv_layers: List[Tuple[int, int, int]],
dropout: float = 0.0,
mode: str = "default",
conv_bias: bool = False,
):
super().__init__()
assert mode in {"default", "layer_norm"}
def block(
n_in,
n_out,
k,
stride,
is_layer_norm=False,
is_group_norm=False,
conv_bias=False,
):
def make_conv():
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
nn.init.kaiming_normal_(conv.weight)
return conv
assert not (
is_layer_norm and is_group_norm
), "layer norm and group norm are exclusive"
if is_layer_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
nn.Sequential(
TransposeLast(),
Fp32LayerNorm(dim, elementwise_affine=True),
TransposeLast(),
),
nn.GELU(),
)
elif is_group_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
Fp32GroupNorm(dim, dim, affine=True),
nn.GELU(),
)
else:
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3, "invalid conv definition: " + str(cl)
(dim, k, stride) = cl
self.conv_layers.append(
block(
in_d,
dim,
k,
stride,
is_layer_norm=mode == "layer_norm",
is_group_norm=mode == "default" and i == 0,
conv_bias=conv_bias,
)
)
in_d = dim
def forward(self, x):
# BxT -> BxCxT
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x
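# Example configuration (sketch): the standard wav2vec 2.0 / HuBERT base
# feature extractor; seven strided 1-D convolutions downsample 16 kHz
# waveforms by a factor of 320, i.e. to a 50 Hz frame rate:
#
#   conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
#   extractor = ConvFeatureExtractionModel(conv_layers, mode="default")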
def make_conv_pos(e, k, g, is_batch_norm=False):
pos_conv = nn.Conv1d(
e,
e,
kernel_size=k,
padding=k // 2,
groups=g,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
nn.init.normal_(pos_conv.weight, mean=0, std=std)
nn.init.constant_(pos_conv.bias, 0)
if not is_batch_norm:
pos_conv = nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2)
pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
else:
batch_norm = nn.BatchNorm1d(e)
pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())
return pos_conv
class TransformerEncoder(nn.Module):
def build_encoder_layer(self, args, **kwargs):
if args.layer_type == "transformer":
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
)
elif args.layer_type == "trf_adp":
use_adp = False
if args.adp_trf_idx == "all":
use_adp = True
else:
adp_trf_idx = list(
range(*[int(g) for g in args.adp_trf_idx.split(":")])
)
if kwargs.get("layer_idx", None) in adp_trf_idx:
use_adp = True
if use_adp:
layer = TransformerSentenceEncoderWithAdapterLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
adapter_num=args.adp_num,
adapter_dim=args.adp_dim,
adapter_act_fn=args.adp_act_fn,
)
else:
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
)
# layer = fsdp_wrap(layer)
# if args.checkpoint_activations:
# layer = checkpoint_wrapper(layer)
return layer
def __init__(self, args):
super().__init__()
self.args = args  # saved so that max_positions() below does not raise AttributeError
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.required_seq_len_multiple = args.required_seq_len_multiple
pos_conv_depth = getattr(args, "pos_conv_depth", 1)
if pos_conv_depth > 1:
num_layers = args.pos_conv_depth
k = max(3, args.conv_pos // num_layers)
def make_conv_block(e, k, g, l):
return nn.Sequential(
*[
nn.Sequential(
nn.Conv1d(
e,
e,
kernel_size=k,
padding=k // 2,
groups=g,
),
SamePad(k),
TransposeLast(),
LayerNorm(e, elementwise_affine=False),
TransposeLast(),
nn.GELU(),
)
for _ in range(l)
]
)
self.pos_conv = make_conv_block(
self.embedding_dim, k, args.conv_pos_groups, num_layers
)
else:
self.pos_conv = make_conv_pos(
self.embedding_dim,
args.conv_pos,
args.conv_pos_groups,
is_batch_norm=args.conv_pos_batch_norm
if hasattr(args, "conv_pos_batch_norm")
else False,
)
self.layers = nn.ModuleList(
[
self.build_encoder_layer(args, layer_idx=ii)
for ii in range(args.encoder_layers)
]
)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
x, layer_results = self.extract_features(
x, padding_mask, layer, corpus_key=corpus_key
)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(
self,
x,
padding_mask=None,
tgt_layer=None,
min_layer=0,
corpus_key=None,
):
if padding_mask is not None:
x = index_put(x, padding_mask, 0)
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0 and padding_mask is None:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
padding_mask, _ = pad_to_multiple(
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
r = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
if not self.training or (dropout_probability > self.layerdrop):
layer_check = layer
# if isinstance(layer, FullyShardedDataParallel):
# layer_check = layer.unwrapped_module
if (corpus_key is None) or (
not isinstance(
layer_check,
(TransformerSentenceEncoderWithAdapterLayer,),
)
):
x, (z, lr) = layer(
x,
self_attn_padding_mask=padding_mask,
need_weights=False,
)
else:
x, (z, lr) = layer(
x,
self_attn_padding_mask=padding_mask,
need_weights=False,
corpus_key=corpus_key,
)
if i >= min_layer:
layer_results.append((x, z, lr))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# undo padding
if pad_length > 0:
x = x[:, :-pad_length]
def undo_pad(a, b, c):
return (
a[:-pad_length],
b[:-pad_length] if b is not None else b,
c[:-pad_length],
)
layer_results = [undo_pad(*u) for u in layer_results]
return x, layer_results
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.args.max_positions
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer implementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
need_weights=False,
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
layer_result = x
x = self.dropout3(x)
x = residual + x
else:
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
layer_result = x
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, (attn, layer_result)
class AdapterFast(nn.Module):
def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
"""
Implements adapter modules directly, with 3D tensor weights as parameters
and without using ModuleList, in order to speed up training throughput.
"""
super().__init__()
self.adapter_num = adapter_num
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
self.act_fn = nn.Identity()
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "gelu":
self.act_fn = nn.GELU()
elif act_fn == "selu":
self.act_fn = nn.SELU()
else:
raise ValueError(f"unsupported {act_fn}")
self.reset_parameters()
def reset_parameters(self):
for ii in range(self.adapter_num):
nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.b_a[ii], -bound, bound)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.b_b[ii], -bound, bound)
nn.init.ones_(self.ln_W)
nn.init.zeros_(self.ln_b)
def forward(self, x, adapter_id):
ii = adapter_id
h = x
h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii])
h = F.linear(h, self.W_a[ii], self.b_a[ii])
h = self.act_fn(h)
h = F.linear(h, self.W_b[ii], self.b_b[ii])
outputs = h
return outputs
def extra_repr(self):
return "adapter={}, input_dim={}, hidden_dim={}".format(
self.adapter_num, self.input_dim, self.hidden_dim
)
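# Usage sketch (illustrative sizes): one shared parameter bank serving
# several corpora, indexed by an integer adapter id at forward time:
#
#   adapters = AdapterFast(adapter_num=2, input_dim=768, hidden_dim=64, act_fn="relu")
#   y = adapters(x, adapter_id=0)  # x: (..., 768) -> y: (..., 768)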
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
"""
Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained
models. An adapter module is added alongside the vanilla Transformer module.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
adapter_num=201,
adapter_dim=64,
adapter_act_fn="relu",
) -> None:
super().__init__(
embedding_dim=embedding_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
layer_norm_first=layer_norm_first,
)
self.adapter_num = adapter_num
self.adapter_dim = adapter_dim
self.adapter_layer = AdapterFast(
adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn
)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
corpus_key=None,
):
x, (attn, layer_result) = super().forward(
x=x,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
att_args=att_args,
)
assert corpus_key is not None
assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}"
y = self.adapter_layer(x, corpus_key[0])
x = x + y
return x, (attn, layer_result)

View File

@ -0,0 +1,52 @@
import os
import jsonlines
from tqdm import tqdm
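# This script attaches frame-level k-means labels (fairseq-style .km/.tsv
# files) to LibriSpeech supervision manifests as a "custom" field. The
# hard-coded paths below are specific to the original authors' cluster
# and must be adapted to your environment.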
os.system(
"cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_dev-clean* ."
)
os.system(
"cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_train* ."
)
os.system("chmod -R 644 *.jsonl.gz")
os.system("gunzip *.gz")
dataset_parts = (
"dev-clean",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
kmeans_dir = "/userhome/user/yangguanrou/data/k500"
idx_dir = "/userhome/user/yangguanrou/data/shu"
kmeans = []
idxs = []
for part in ["train", "valid"]:
with open(kmeans_dir + "/" + part + ".km", "r") as f:
kmeans += f.read().splitlines()
with open(idx_dir + "/" + part + ".tsv", "r") as f:
lines = f.read().splitlines()
idxs += [
line.split("\t", -1)[0].split("/", -1)[-1].replace(".flac", "")
for line in lines
if ".flac" in line
]
idx2kmeans = {}
for idx, km in zip(idxs, kmeans):
idx2kmeans[idx] = km
for part in dataset_parts:
with jsonlines.open(f"librispeech_supervisions_{part}.jsonl") as reader:
with jsonlines.open(
f"librispeech_supervisions_{part}_new.jsonl", mode="w"
) as writer:
for obj in tqdm(reader):
obj["custom"] = {"kmeans": idx2kmeans[obj["id"]]}
writer.write(obj)
os.system('for file in *_new.jsonl; do mv "$file" "${file%_new.jsonl}.jsonl"; done')

View File

@ -0,0 +1,18 @@
# simple script to convert a fairseq checkpoint into pytorch parameter state dict
from argparse import ArgumentParser
from collections import OrderedDict
import torch
parser = ArgumentParser()
parser.add_argument("--src")
parser.add_argument("--tgt")
args = parser.parse_args()
src = args.src
tgt = args.tgt
old_checkpoint = torch.load(src, map_location="cpu")
new_checkpoint = OrderedDict()
new_checkpoint["model"] = old_checkpoint["model"]
torch.save(new_checkpoint, tgt)
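# Example invocation (paths are placeholders; the script name depends on
# where this file is saved in the recipe):
#
#   python convert_checkpoint.py --src hubert_fairseq.pt --tgt hubert_torch.pt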

View File

@ -0,0 +1,259 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/text,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
Token symbol table that mapping token to token ids.
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
tokens.
"""
lexicon = []
for word in words:
chars = list(word.strip(" \t"))
if contain_oov(token_sym_table, chars):
continue
lexicon.append((word, chars))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(text_file: str) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
text_file:
A file that contains text lines to generate tokens.
Returns:
Return a dict whose keys are tokens and values are token ids ranged
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
whitespace = re.compile(r"([ \t\r\n]+)")
with open(text_file, "r", encoding="utf-8") as f:
for line in f:
line = re.sub(whitespace, "", line)
chars = list(line)
for char in chars:
if char not in tokens:
tokens[char] = len(tokens)
return tokens
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(text_file)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()
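# Example invocation (sketch; the script file name is an assumption, and
# lang_dir must already contain the `text` and `words.txt` files described
# in the module docstring):
#
#   ./prepare_char.py --lang-dir data/lang_char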

View File

@ -0,0 +1,388 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
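# Worked example (sketch): two words sharing a pronunciation each get a
# distinct disambiguation symbol, and a pronunciation that is a prefix of
# another gets one as well:
#
#   add_disambig_symbols([("foo", ["a", "b"]), ("bar", ["a", "b"]), ("baz", ["a"])])
#   -> ([("foo", ["a", "b", "#1"]), ("bar", ["a", "b", "#2"]), ("baz", ["a", "#1"])], 2)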
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()
def main():
out_dir = Path(get_args().lang_dir)
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
return parser.parse_args()
def process_wav_librispeech(
dataset: Optional[str] = None,
):
src_dir = Path("data/manifests")
output_dir = Path("data/wav")
if dataset is None:
dataset_parts = (
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
else:
dataset_parts = dataset.split(" ", -1)
prefix = "librispeech"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
process_wav_librispeech(
dataset=args.dataset,
)

View File

@ -0,0 +1,104 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
return parser.parse_args()
def process_kmeans_librispeech(
dataset: Optional[str] = None,
):
src_dir = Path(".")
output_dir = Path(".")
if dataset is None:
dataset_parts = (
"dev-clean",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
else:
dataset_parts = dataset.split(" ", -1)
prefix = "librispeech"
suffix = "jsonl"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}_raw.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
process_kmeans_librispeech(
dataset=args.dataset,
)

View File

@ -0,0 +1,23 @@
import os
import jsonlines
from tqdm import tqdm
dataset_parts = (
"dev-clean",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
for part in dataset_parts:
with jsonlines.open(f"librispeech_cuts_{part}_raw.jsonl") as reader:
with jsonlines.open(f"librispeech_cuts_{part}.jsonl", mode="w") as writer:
for obj in tqdm(reader):
obj["custom"] = {"kmeans": obj["supervisions"][0]["custom"]["kmeans"]}
del obj["supervisions"][0]["custom"]
writer.write(obj)
os.system("rm *_raw.jsonl")
os.system("gzip *.jsonl")

egs/librispeech/SSL/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -0,0 +1 @@
../hubert/asr_datamodule.py

View File

@ -0,0 +1 @@
../../ASR/zipformer/beam_search.py

View File

@ -0,0 +1 @@
../hubert/dataset.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../ASR/zipformer/encoder_interface.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,601 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScheduledFloat
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel
from zipformer import Zipformer2
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
add_masks: bool = False,
seed: Optional[int] = None,
epoch: Optional[int] = None,
indices: Optional[torch.Tensor] = None,
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
mask_dropout: randomly dropout this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
if num_mask_ver == 1:
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if seed is not None and epoch is not None and indices is not None:
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
else:
seed_i = None
rng = np.random.default_rng(seed_i)
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
assert sz >= 0, sz
else:
sz = all_sz
if num_mask_ver == 1:
if padding_mask is not None:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
num_mask = all_num_mask
elif num_mask_ver == 2:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ rng.random()
)
num_mask = max(min_masks, num_mask)
else:
raise ValueError()
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = rng.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = rng.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
if mask_type == "static":
raise ValueError(f"this should never happens")
else:
lengths = [min(mask_length, sz - 1)]
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = rng.integers(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int64,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = rng.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
if idc_select_ver == 1:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
elif idc_select_ver == 2:
mask_idc = rng.choice(sz, num_mask, replace=False)
else:
raise ValueError()
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idc = np.unique(mask_idc[mask_idc < sz])
if len(mask_idc) >= sz:
raise ValueError(
(
f"the entire sequence is masked. "
f"sz={sz}; mask_idc[mask_idc]; "
f"index={indices[i] if indices is not None else None}"
)
)
mask_idcs.append(mask_idc)
target_len = None
if require_same_masks:
if add_masks:
target_len = max([len(m) for m in mask_idcs])
else:
target_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if target_len is not None and len(mask_idc) > target_len:
mask_idc = rng.choice(mask_idc, target_len, replace=False)
mask[i, mask_idc] = True
if target_len is not None and len(mask_idc) < target_len:
unmasked = np.flatnonzero(~mask[i])
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
mask[i, to_mask] = True
if mask_dropout > 0:
masked = np.flatnonzero(mask[i])
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
to_drop = rng.choice(masked, num_holes, replace=False)
mask[i, to_drop] = False
return mask
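# A minimal usage sketch (illustrative values, not taken from the recipe):
# mask spans of length 10 covering roughly 65% of 50 frames for a batch of
# 2 utterances with no padding.
def _demo_compute_mask_indices():
    mask = compute_mask_indices(
        shape=(2, 50),
        padding_mask=None,
        mask_prob=0.65,
        mask_length=10,
        min_masks=2,
    )
    # mask is a boolean np.ndarray of shape (2, 50); True marks masked frames
    assert mask.shape == (2, 50)
    return mask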
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
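# A minimal usage sketch: the per-stack Zipformer2 hyper-parameters below
# arrive as comma-separated strings; the values here are illustrative.
assert _to_int_tuple("192,256,384") == (192, 256, 384)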
class HubertModel(nn.Module):
def __init__(
self,
cfg,
) -> None:
super().__init__()
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
self.post_extract_proj = (
nn.Linear(self.embed, encoder_input_dim)
if self.embed != encoder_input_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())
self.encoder = Zipformer2(
output_downsampling_factor=1,
downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
encoder_dim=_to_int_tuple(cfg.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(cfg.query_head_dim),
pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
value_head_dim=_to_int_tuple(cfg.value_head_dim),
pos_dim=cfg.pos_dim,
num_heads=_to_int_tuple(cfg.num_heads),
feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
)
self.layer_norm = LayerNorm(self.embed)
self.untie_final_proj = cfg.untie_final_proj
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
# modules below are not needed during fine-tuning
self.num_classes = cfg.num_classes
self.pred_masked_weight = cfg.pred_masked_weight
self.pred_nomask_weight = cfg.pred_nomask_weight
self.loss_weights = cfg.loss_weights
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq."""
    # nn.Module has no upgrade_state_dict_named, so there is no super()
    # call to make here; simply return the state dict unchanged.
    return state_dict
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb.to(x.dtype)
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
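# The padding mask arrives at the raw-sample rate while `features` are at
# the frame rate: trim any leftover samples, view the mask as
# (batch, num_frames, samples_per_frame), and mark a frame as padded only
# when every sample it covers is padded.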
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
):
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float -> (T, B, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x = x.transpose(0, 1)
x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
x = x.transpose(0, 1)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
proj_x_m /= self.logit_temp
logit_m_list = [proj_x_m for _ in range(len(target_list))]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
proj_x_u /= self.logit_temp
logit_u_list = [proj_x_u for _ in range(len(target_list))]
else:
logit_u_list = [None for _ in target_list]
# result = {
# "logit_m_list": logit_m_list,
# "logit_u_list": logit_u_list,
# "padding_mask": padding_mask,
# "features_pen": features_pen,
# }
targ_m_list = target_list[0][masked_indices]
targ_m_list = targ_m_list.long()
targ_m_list = [targ_m_list for _ in range(len(target_list))]
targ_u_list = target_list[0][nomask_indices]
targ_u_list = targ_u_list.long()
targ_u_list = [targ_u_list for _ in range(len(target_list))]
return self.compute_loss(
logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
)
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x.float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self):
self.final_proj = None
def compute_loss(
self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
):
loss = 0.0
sample_size = 0
logging_output = {}
reduce = True
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = [x.float() for x in logit_m_list if x is not None]
logp_m_list = torch.cat(logp_m_list)
targ_m_list = torch.cat(targ_m_list)
loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_0"] = loss_m.detach().item()
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += len(targ_m_list)
loss_u_list = []
logp_u_list = [x.float() for x in logit_u_list if x is not None]
logp_u_list = torch.cat(logp_u_list)
targ_u_list = torch.cat(targ_u_list)
loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_0"] = loss_u.detach().item()
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += len(targ_u_list)
if self.loss_weights is not None:
extra_losses = []
names = []
extra_losses.append(features_pen)
names.append("features_pen")
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
**logging_output,
}
# for lk in self.log_keys:
# if lk in net_output:
# logging_output[lk] = float((net_output[lk]))
def compute_correct(logits, target):
    if logits.numel() == 0:
        return 0, 0
    else:
        assert logits.dim() > 1, logits.shape
        # avoid shadowing the max/min builtins
        is_max = logits.argmax(-1) == target
        is_min = logits.argmin(-1) == target
        both = is_max & is_min
        corr = is_max.long().sum().item() - both.long().sum().item()
        count = is_max.numel()
        return corr, count
with torch.no_grad():
corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
logging_output[f"correct_m_0"] = corr_m
logging_output[f"count_m_0"] = count_m
corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
logging_output[f"correct_u_0"] = corr_u
logging_output[f"count_u_0"] = count_u
return loss, sample_size, logging_output
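# A minimal standalone sketch (illustrative sizes) of the cross-entropy
# step inside compute_loss above: one logit row per masked frame, targets
# drawn from num_classes clustering labels.
def _demo_masked_ce_loss(num_frames: int = 8, num_classes: int = 500):
    logits = torch.randn(num_frames, num_classes)
    targets = torch.randint(0, num_classes, (num_frames,))
    # reduction="sum" matches compute_loss; dividing by the number of
    # frames (the sample_size) recovers a per-frame loss for logging.
    loss = F.cross_entropy(logits, targets, reduction="sum")
    return loss / num_frames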


@@ -0,0 +1 @@
../../ASR/zipformer/joiner.py


@@ -0,0 +1,344 @@
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
class AsrModel(nn.Module):
def __init__(
self,
encoder,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
encoder_dim: int = 768,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder:
It is the transcription network in the paper. It accepts
inputs `x` of shape (N, T, encoder_dim).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_encoder(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 2-D tensor of shape (N, T).
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
if padding_mask is None:
padding_mask = torch.zeros_like(x, dtype=torch.bool)
encoder_out, padding_mask = self.encoder.extract_features(
source=x,
padding_mask=padding_mask,
mask=self.encoder.training,
)
encoder_out_lens = torch.sum(~padding_mask, dim=1)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
y: k2.RaggedTensor,
padding_mask: Optional[torch.Tensor] = None,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 2-D tensor of shape (N, T).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
Note:
Regarding am_scale & lm_scale, they make the loss function take
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 2, x.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == y.dim0, (x.shape, y.dim0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
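# A minimal usage sketch with a hypothetical dummy encoder (the recipes
# pass the pretrained HuBERT encoder instead); CTC-only, so no decoder or
# joiner is needed. The 320x downsampling and all sizes are illustrative.
class _DummyEncoder(nn.Module):
    def __init__(self, dim: int = 768):
        super().__init__()
        self.proj = nn.Linear(1, dim)

    def extract_features(self, source, padding_mask, mask):
        # Pretend every 320 raw samples collapse into one output frame.
        return self.proj(source[:, ::320].unsqueeze(-1)), padding_mask[:, ::320]


def _demo_asr_model():
    model = AsrModel(encoder=_DummyEncoder(), use_transducer=False, use_ctc=True)
    x = torch.randn(2, 16000)  # (N, T) raw audio samples
    y = k2.RaggedTensor([[3, 5, 8], [4, 7]])  # per-utterance label ids
    simple_loss, pruned_loss, ctc_loss, encoder_out_lens = model(x, y)
    return ctc_loss, encoder_out_lens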


@@ -0,0 +1 @@
../../ASR/zipformer/optim.py

File diff suppressed because it is too large

@@ -0,0 +1 @@
../../ASR/zipformer/scaling.py


@@ -0,0 +1 @@
../hubert/ssl_datamodule.py


@@ -0,0 +1,337 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import Callable, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
def relu_squared(x: torch.Tensor):
return F.relu(x).pow(2)
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def is_xla_tensor(tensor):
return torch.is_tensor(tensor) and tensor.device.type == "xla"
def index_put(tensor, indices, value):
if is_xla_tensor(tensor):
for _ in range(indices.dim(), tensor.dim()):
indices = indices.unsqueeze(-1)
if indices.size(-1) < tensor.size(-1):
indices = indices.expand_as(tensor)
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
else:
tensor[indices] = value
return tensor
def pad_to_multiple(x, multiple, dim=-1, value=0):
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
if x is None:
return None, 0
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
if m.is_integer():
return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
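# A minimal usage sketch: pad the last dimension of a (2, 7) tensor up to
# a multiple of 4; remainder reports how much padding was added.
def _demo_pad_to_multiple():
    x = torch.zeros(2, 7)
    padded, remainder = pad_to_multiple(x, multiple=4, dim=-1)
    assert padded.shape == (2, 8) and remainder == 1
    return padded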
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str) -> Callable:
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "relu_squared":
return relu_squared
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "swish":
return F.silu
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
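# A minimal usage sketch: with an even kernel and symmetric padding, a
# Conv1d produces one extra output frame; SamePad trims it so the time
# dimension is preserved.
def _demo_same_pad(T: int = 10):
    conv = nn.Conv1d(4, 4, kernel_size=2, padding=1)
    same = SamePad(kernel_size=2)
    x = torch.randn(1, 4, T)
    y = same(conv(x))
    assert y.shape[-1] == T
    return y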
class SamePad2d(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
assert len(x.size()) == 4
if self.remove > 0:
x = x[:, :, : -self.remove, : -self.remove]
return x
class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None, transpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.transpose_dim = transpose_dim

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(self.transpose_dim, -1)
try:
from apex.normalization import FusedLayerNorm as _FusedLayerNorm
has_fused_layernorm = True
class FusedLayerNorm(_FusedLayerNorm):
@torch.jit.unused
def forward(self, x):
if not x.is_cuda:
return super().forward(x)
else:
with torch.cuda.device(x.device):
return super().forward(x)
except ImportError:
has_fused_layernorm = False
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
if torch.jit.is_scripting() or torch.jit.is_tracing():
export = True
if not export and torch.cuda.is_available() and has_fused_layernorm:
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class Fp32LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class Fp32GroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.group_norm(
input.float(),
self.num_groups,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
def softmax(x, dim: int, onnx_trace: bool = False):
if onnx_trace:
return F.softmax(x.float(), dim=dim)
else:
return F.softmax(x, dim=dim, dtype=torch.float32)
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
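# A minimal usage sketch: wrap a Linear layer so that, in training mode,
# random weight blocks of size 8 are zeroed (and the survivors rescaled
# by 1 / (1 - p)) on every forward pass.
def _demo_quant_noise():
    layer = quant_noise(nn.Linear(64, 32), p=0.1, block_size=8)
    layer.train()
    return layer(torch.randn(4, 64))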
class FairseqDropout(nn.Module):
def __init__(self, p, module_name=None):
super().__init__()
self.p = p
self.module_name = module_name
self.apply_during_inference = False
def forward(self, x, inplace: bool = False):
if self.p > 0 and (self.training or self.apply_during_inference):
return F.dropout(x, p=self.p, training=True, inplace=inplace)
else:
return x
def make_generation_fast_(
self,
name: str,
retain_dropout: bool = False,
retain_dropout_modules: Optional[List[str]] = None,
**kwargs
):
if retain_dropout:
if retain_dropout_modules is not None and self.module_name is None:
pass
elif (
retain_dropout_modules is None # if None, apply to all modules
or self.module_name in retain_dropout_modules
):
self.apply_during_inference = True
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
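# A minimal usage sketch: GradMultiply leaves the forward value unchanged
# but scales the gradient flowing back into x by the given factor, which
# is how feature_grad_mult damps the conv front-end's gradients.
def _demo_grad_multiply():
    x = torch.randn(3, requires_grad=True)
    y = GradMultiply.apply(x, 0.5)
    y.sum().backward()
    return x.grad  # 0.5 everywhere instead of 1.0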


@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
class ConvFeatureExtractionModel(nn.Module):
def __init__(
self,
conv_layers: List[Tuple[int, int, int]],
dropout: float = 0.0,
mode: str = "default",
conv_bias: bool = False,
):
super().__init__()
assert mode in {"default", "layer_norm"}
def block(
n_in,
n_out,
k,
stride,
is_layer_norm=False,
is_group_norm=False,
conv_bias=False,
):
def make_conv():
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
nn.init.kaiming_normal_(conv.weight)
return conv
assert not (
    is_layer_norm and is_group_norm
), "layer norm and group norm are exclusive"
if is_layer_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
nn.Sequential(
TransposeLast(),
Fp32LayerNorm(dim, elementwise_affine=True),
TransposeLast(),
),
nn.GELU(),
)
elif is_group_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
Fp32GroupNorm(dim, dim, affine=True),
nn.GELU(),
)
else:
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3, "invalid conv definition: " + str(cl)
(dim, k, stride) = cl
self.conv_layers.append(
block(
in_d,
dim,
k,
stride,
is_layer_norm=mode == "layer_norm",
is_group_norm=mode == "default" and i == 0,
conv_bias=conv_bias,
)
)
in_d = dim
def forward(self, x):
# BxT -> BxCxT
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x
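# A minimal sketch using the standard wav2vec 2.0 / HuBERT conv spec
# (seven (dim, kernel, stride) layers with an overall stride of 320), so
# 16 kHz audio comes out at roughly 50 frames per second.
def _demo_conv_extractor():
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    extractor = ConvFeatureExtractionModel(conv_layers=conv_layers)
    wav = torch.randn(1, 16000)  # (B, T) raw waveform
    feats = extractor(wav)  # (B, 512, 49)
    return feats.shape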

File diff suppressed because it is too large