# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from attention_module import MultiheadAttention, init_bert_params
from utils import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    LayerNorm,
    SamePad,
    TransposeLast,
    get_activation_fn,
    index_put,
    pad_to_multiple,
)


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (
                is_layer_norm and is_group_norm
            ) == False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        # `dim` is captured from the enclosing loop below
                        # (the output channels of the current conv layer).
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
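

# --- Illustrative usage sketch (not part of the upstream fairseq code) -------
# A minimal example of how ConvFeatureExtractionModel is typically configured.
# The layer spec below is the one commonly used by wav2vec 2.0 / HuBERT base
# models; treat both the spec and this helper's name as illustrative
# assumptions rather than defaults defined in this file.
def _example_conv_feature_extractor():
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    extractor = ConvFeatureExtractionModel(conv_layers=conv_layers, mode="default")
    wav = torch.randn(2, 16000)  # (batch, samples) of raw audio
    feats = extractor(wav)  # -> (batch, 512, ~49) for 1 s of 16 kHz audio
    return feats.shape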


def make_conv_pos(e, k, g, is_batch_norm=False):
    pos_conv = nn.Conv1d(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
    )
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(pos_conv.weight, mean=0, std=std)
    nn.init.constant_(pos_conv.bias, 0)

    if not is_batch_norm:
        pos_conv = nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2)
        pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
    else:
        batch_norm = nn.BatchNorm1d(e)
        pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())

    return pos_conv
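

# --- Illustrative usage sketch (not part of the upstream fairseq code) -------
# make_conv_pos builds the grouped convolution used as a convolutional
# positional embedding; SamePad trims the extra frame produced by the even
# kernel so the output length matches the input. The kernel size (128) and
# group count (16) below are the usual defaults and are assumptions here.
def _example_conv_pos():
    pos_conv = make_conv_pos(e=768, k=128, g=16)
    x = torch.randn(2, 768, 100)  # (batch, channels, time)
    y = pos_conv(x)
    assert y.shape == x.shape  # length preserved after SamePad trimming
    return y.shape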


class TransformerEncoder(nn.Module):
    def build_encoder_layer(self, args, **kwargs):
        if args.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
        elif args.layer_type == "trf_adp":
            use_adp = False
            if args.adp_trf_idx == "all":
                use_adp = True
            else:
                adp_trf_idx = list(
                    range(*[int(g) for g in args.adp_trf_idx.split(":")])
                )
                if kwargs.get("layer_idx", None) in adp_trf_idx:
                    use_adp = True
            if use_adp:
                layer = TransformerSentenceEncoderWithAdapterLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    adapter_num=args.adp_num,
                    adapter_dim=args.adp_dim,
                    adapter_act_fn=args.adp_act_fn,
                )
            else:
                layer = TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                )

        # layer = fsdp_wrap(layer)
        # if args.checkpoint_activations:
        #     layer = checkpoint_wrapper(layer)
        return layer
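
    # Note (added for clarity, not in the upstream code): with
    # layer_type == "trf_adp", args.adp_trf_idx is either "all" or a
    # slice-style string such as "12:24"; the latter is parsed with
    # range(12, 24), so only layers 12..23 receive adapter modules.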

    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple

        pos_conv_depth = getattr(args, "pos_conv_depth", 1)
        if pos_conv_depth > 1:
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            SamePad(k),
                            TransposeLast(),
                            LayerNorm(e, elementwise_affine=False),
                            TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, args.conv_pos_groups, num_layers
            )

        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                args.conv_pos,
                args.conv_pos_groups,
                is_batch_norm=args.conv_pos_batch_norm
                if hasattr(args, "conv_pos_batch_norm")
                else False,
            )

        self.layers = nn.ModuleList(
            [
                self.build_encoder_layer(args, layer_idx=ii)
                for ii in range(args.encoder_layers)
            ]
        )
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
        x, layer_results = self.extract_features(
            x, padding_mask, layer, corpus_key=corpus_key
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
        corpus_key=None,
    ):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None

        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                layer_check = layer
                # if isinstance(layer, FullyShardedDataParallel):
                #     layer_check = layer.unwrapped_module
                if (corpus_key is None) or (
                    not isinstance(
                        layer_check,
                        (TransformerSentenceEncoderWithAdapterLayer,),
                    )
                ):
                    x, (z, lr) = layer(
                        x,
                        self_attn_padding_mask=padding_mask,
                        need_weights=False,
                    )
                else:
                    x, (z, lr) = layer(
                        x,
                        self_attn_padding_mask=padding_mask,
                        need_weights=False,
                        corpus_key=corpus_key,
                    )
            if i >= min_layer:
                layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict
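

# --- Illustrative usage sketch (not part of the upstream fairseq code) -------
# TransformerEncoder is driven entirely by an `args` object; the attributes it
# reads are collected here in a SimpleNamespace so the expected configuration
# surface is explicit. The concrete values mirror typical "base"-model
# settings and are assumptions, not defaults defined in this file.
def _example_transformer_encoder():
    from types import SimpleNamespace

    args = SimpleNamespace(
        dropout=0.1,
        encoder_embed_dim=768,
        required_seq_len_multiple=2,
        conv_pos=128,
        conv_pos_groups=16,
        encoder_layers=2,  # kept tiny for the sketch
        encoder_ffn_embed_dim=3072,
        encoder_attention_heads=12,
        attention_dropout=0.1,
        activation_dropout=0.0,
        activation_fn="gelu",
        layer_norm_first=False,
        encoder_layerdrop=0.0,
        layer_type="transformer",
    )
    encoder = TransformerEncoder(args)
    feats = torch.randn(2, 50, 768)  # (batch, frames, encoder_embed_dim)
    out, layer_results = encoder(feats)
    # Each layer_results entry is (hidden_states, attn, ffn_output);
    # hidden states are stored in (T, B, C) layout.
    return out.shape, len(layer_results)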


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules, similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)

            layer_result = x

            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, (attn, layer_result)
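

# --- Illustrative usage sketch (not part of the upstream fairseq code) -------
# The same layer runs in post-norm (layer_norm_first=False, BERT-style) or
# pre-norm (layer_norm_first=True) configuration; in the pre-norm case the
# top-level TransformerEncoder applies its final self.layer_norm afterwards.
# The sizes below are illustrative.
def _example_encoder_layer():
    layer = TransformerSentenceEncoderLayer(
        embedding_dim=768,
        ffn_embedding_dim=3072,
        num_attention_heads=12,
        layer_norm_first=True,
    )
    x = torch.randn(50, 2, 768)  # (time, batch, channels), as used by the encoder
    out, (attn, ffn_out) = layer(x)  # attn is None because need_weights=False
    return out.shape, ffn_out.shape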


class AdapterFast(nn.Module):
    def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
        """
        Implements adapter modules directly with 3D tensor weights as parameters
        and without using ModuleList, to speed up training throughput.
        """
        super().__init__()

        self.adapter_num = adapter_num
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
        self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
        self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
        self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))

        self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.act_fn = nn.Identity()
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "gelu":
            self.act_fn = nn.GELU()
        elif act_fn == "selu":
            self.act_fn = nn.SELU()
        else:
            raise ValueError(f"unsupported {act_fn}")

        self.input_dim = input_dim
        self.reset_parameters()

    def reset_parameters(self):
        for ii in range(self.adapter_num):
            nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_a[ii], -bound, bound)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_b[ii], -bound, bound)

        nn.init.ones_(self.ln_W)
        nn.init.zeros_(self.ln_b)

    def forward(self, x, adapter_id):
        ii = adapter_id
        h = x
        h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii])
        h = F.linear(h, self.W_a[ii], self.b_a[ii])
        h = self.act_fn(h)
        h = F.linear(h, self.W_b[ii], self.b_b[ii])
        outputs = h
        return outputs

    def extra_repr(self):
        return "adapter={}, input_dim={}, hidden_dim={}".format(
            self.adapter_num, self.input_dim, self.hidden_dim
        )
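

# --- Illustrative usage sketch (not part of the upstream fairseq code) -------
# AdapterFast keeps the weights of all adapters stacked in shared tensors and
# selects one adapter by index at call time: LayerNorm -> down-projection ->
# activation -> up-projection. The sizes below are arbitrary for illustration.
def _example_adapter_fast():
    adapter = AdapterFast(adapter_num=4, input_dim=768, hidden_dim=64, act_fn="relu")
    x = torch.randn(50, 2, 768)  # (time, batch, channels)
    y = adapter(x, adapter_id=1)  # bottleneck output of adapter #1
    assert y.shape == x.shape
    return y.shape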


class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Transformer Encoder Layer with adapters, used in BERT/XLM style
    pre-trained models. An adapter module is added alongside the vanilla
    Transformer module.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        adapter_num=201,
        adapter_dim=64,
        adapter_act_fn="relu",
    ) -> None:
        super().__init__(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            layer_norm_first=layer_norm_first,
        )

        self.adapter_num = adapter_num
        self.adapter_dim = adapter_dim
        self.adapter_layer = AdapterFast(
            adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
        corpus_key=None,
    ):
        x, (attn, layer_result) = super().forward(
            x=x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
            need_weights=need_weights,
            att_args=att_args,
        )
        assert corpus_key is not None
        assert (
            len(set(corpus_key)) == 1
        ), f"corpus_key items are not the same: {corpus_key}"
        y = self.adapter_layer(x, corpus_key[0])
        x = x + y
        return x, (attn, layer_result)
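

# --- Illustrative usage sketch (not part of the upstream fairseq code) -------
# The adapter layer expects corpus_key to be a per-utterance list of integer
# adapter ids that are all identical within a batch; the shared id selects
# which adapter in AdapterFast is applied on top of the Transformer output.
# The numbers below are arbitrary for illustration.
def _example_adapter_encoder_layer():
    layer = TransformerSentenceEncoderWithAdapterLayer(
        embedding_dim=768,
        ffn_embedding_dim=3072,
        num_attention_heads=12,
        adapter_num=4,
        adapter_dim=64,
    )
    x = torch.randn(50, 2, 768)  # (time, batch, channels)
    out, (attn, ffn_out) = layer(x, corpus_key=[0, 0])  # one id per utterance
    return out.shape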