# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def relu_squared(x: torch.Tensor):
    return F.relu(x).pow(2)


def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )


def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"


def index_put(tensor, indices, value):
    if is_xla_tensor(tensor):
        for _ in range(indices.dim(), tensor.dim()):
            indices = indices.unsqueeze(-1)
        if indices.size(-1) < tensor.size(-1):
            indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor


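# Example (illustrative sketch): `index_put` is a masked assignment that also
# works on XLA tensors, where in-place advanced indexing may not be supported;
# on other devices it falls back to plain `tensor[indices] = value`.
#
#     feats = torch.zeros(2, 4, 3)
#     mask = torch.tensor([[True, False, False, True], [False, True, False, False]])
#     feats = index_put(feats, mask, 1.0)  # positions selected by `mask` become 1.0

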
def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


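# Example (illustrative sketch): `pad_to_multiple` right-pads `dim` with `value`
# until its size is a multiple of `multiple`, and also reports how much padding
# was added so callers can strip it again later.
#
#     x = torch.randn(2, 10, 8)  # (batch, time, feature)
#     padded, n_pad = pad_to_multiple(x, multiple=4, dim=-2)
#     # padded.shape == (2, 12, 8) and n_pad == 2; n_pad == 0 when no padding is needed

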
def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`"""
    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # Return the functional form so that, like the other branches, the
        # result can be applied directly to a tensor.
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


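# Example (illustrative sketch): the returned object is a plain callable that
# maps a tensor to a tensor of the same shape.
#
#     act = get_activation_fn("gelu")
#     y = act(torch.randn(2, 3))  # GELU applied elementwise
#     get_activation_fn("mish")   # raises RuntimeError: unsupported name

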
class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


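# Example (illustrative sketch): SamePad is meant to follow a padded Conv1d so
# the output length matches the input length ("same" padding), trimming the one
# extra frame produced by an even kernel, or kernel_size - 1 frames for a
# causal convolution.
#
#     conv = nn.Conv1d(16, 16, kernel_size=4, padding=2)  # even kernel -> length 101
#     trim = SamePad(kernel_size=4)
#     y = trim(conv(torch.randn(1, 16, 100)))  # y.shape == (1, 16, 100)

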
class SamePad2d(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        assert len(x.size()) == 4
        if self.remove > 0:
            x = x[:, :, : -self.remove, : -self.remove]
        return x


class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None, transpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.transpose_dim = transpose_dim

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(self.transpose_dim, -1)


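# Example (illustrative sketch): TransposeLast swaps `transpose_dim` (default -2)
# with the last dimension; inside nn.Sequential it is typically wrapped around a
# LayerNorm so that normalization runs over the channel dimension of
# (batch, channels, time) feature maps.
#
#     t = TransposeLast()
#     t(torch.randn(4, 512, 99)).shape  # torch.Size([4, 99, 512])

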
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


def softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper,
          which consists of randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(
                    in_features // block_size * out_features,
                    device=weight.device,
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module


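# Example (illustrative sketch): `quant_noise` registers a forward pre-hook, so
# the block dropout happens transparently on every training-mode forward pass
# and is skipped entirely at eval time.
#
#     layer = quant_noise(nn.Linear(256, 256), p=0.1, block_size=8)
#     layer.train()
#     _ = layer(torch.randn(4, 256))  # ~10% of the 8-wide weight blocks are zeroed,
#                                     # and surviving weights are scaled by 1 / (1 - p)

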
class FairseqDropout(nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                pass
            elif (
                retain_dropout_modules is None  # if None, apply to all modules
                or self.module_name in retain_dropout_modules
            ):
                self.apply_during_inference = True


class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


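# Example (illustrative sketch): GradMultiply is an identity in the forward pass
# but scales gradients in the backward pass, as used to down-weight the
# convolutional feature extractor's gradients in wav2vec 2.0 / HuBERT-style
# training.
#
#     x = torch.randn(3, 5, requires_grad=True)
#     y = GradMultiply.apply(x, 0.5)  # y has the same values as x
#     y.sum().backward()
#     # x.grad is 0.5 everywhere: activations unchanged, gradients halved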