mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
* Add k2SSL * fix flake8 * fix for black * fix for black * fix for black * Update ssl_datamodule.py * Fix bugs in HubertDataset * update comments * add librilight * add checkpoint convert script * format --------- Co-authored-by: yifanyeung <yifanyeung@yifanyeung.local> Co-authored-by: zzasdf <15218404468@163.com>
941 lines
30 KiB
Python
941 lines
30 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
# copies of the Software, and to permit persons to whom the Software is
|
|
# furnished to do so, subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be included in all
|
|
# copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import argparse
|
|
import logging
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from utils import GradMultiply, LayerNorm
|
|
from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
|
|
|
|
|
|
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional boolean mask of the same size as shape (True = padded
            position), which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be
            masked. this will be multiplied by number of timesteps divided by length of
            mask span to mask approximately this percentage of all elements. however due
            to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: length of a mask span (mean / parameter for non-static types)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev
                mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        mask_other: secondary parameter used by the "uniform" and "normal" mask types
        min_masks: minimum number of masked spans
        no_overlap: if True, will switch to an alternative recursive algorithm that
            prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep
            unmasked between spans
        require_same_masks: if true, will randomly drop out masks (or, with add_masks,
            add masks) until the same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
        add_masks: used with require_same_masks; pad every row up to the largest mask
            count instead of trimming every row down to the smallest one
        seed, epoch, indices: when all three are given, row i uses a generator seeded
            from (seed, epoch, indices[i]), making the masks reproducible
        idc_select_ver: 1 = span starts drawn from [0, sz - shortest span);
            2 = span starts drawn from [0, sz)
        num_mask_ver: 1 = rounding noise from the global np.random (and all_sz when no
            padding mask); 2 = rounding noise from the per-row generator

    Returns:
        boolean array of shape ``shape``; True marks masked positions.

    Raises:
        ValueError: unsupported num_mask_ver / idc_select_ver, a "static" mask whose
            total span length is 0, or a row that would be masked entirely.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            # number of non-padded positions in this row
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError(f"unsupported num_mask_ver: {num_mask_ver}")

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # numpy's Generator has no `randint`; `integers` is its replacement
            # (high is exclusive, as with the legacy randint).
            lengths = rng.integers(
                int(mask_other), mask_length * 2 + 1, size=num_mask
            )
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                # Fall back to a single span; num_mask must follow, otherwise
                # the span expansion below indexes past the end of `lengths`.
                lengths = [min(mask_length, sz - 1)]
                num_mask = 1

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                """Place one span inside [s, e) and return the leftover parts."""
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + k for k in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # only parts that can host this span (plus spacing) are eligible;
                # choose one with probability proportional to its length
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    np.int64,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / l_sum
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            # explicit integer dtype so an empty result is still a valid index
            mask_idc = np.asarray(mask_idc, dtype=np.int64)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError(f"unsupported idc_select_ver: {idc_select_ver}")

            # expand each span start into the positions it covers
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; num_masked={len(mask_idc)}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks and mask_idcs:
        if add_masks:
            target_len = max(len(m) for m in mask_idcs)
        else:
            target_len = min(len(m) for m in mask_idcs)

    # NOTE: `rng` below is the generator of the last row processed above.
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
|
|
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0", ...).

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every explicitly passed value would enable the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def add_hubert_arguments(parser: argparse.ArgumentParser):
    """Register the HuBERT model / masking / loss options on ``parser``.

    Boolean options take an explicit value, e.g. ``--skip-masked true``.
    """
    parser.add_argument(
        "--label-rate",
        type=float,
        default=50,
    )

    parser.add_argument(
        "--sample-rate",
        type=float,
        default=16000,
    )

    parser.add_argument(
        "--extractor-mode",
        type=str,
        default="default",
        help="""mode for feature extractor, should be in EXTRACTOR_MODE_CHOICES. default has a single group
        norm with d groups in the first conv block, whereas layer_norm
        has layer norms in every block (meant to use with normalize=True)""",
    )
    parser.add_argument(
        "--encoder-layers",
        type=int,
        default=12,
        help="num encoder layers in the transformer",
    )

    parser.add_argument(
        "--encoder-embed-dim",
        type=int,
        default=768,
        help="encoder embedding dimension",
    )

    parser.add_argument(
        "--encoder-ffn-embed-dim",
        type=int,
        default=3072,
        help="encoder embedding dimension for FFN",
    )

    parser.add_argument(
        "--encoder-attention-heads",
        type=int,
        default=12,
        help="num encoder attention heads",
    )

    parser.add_argument(
        "--activation-fn",
        type=str,
        choices=[
            "relu",
            "gelu",
            "gelu_fast",
            "gelu_accurate",
            "tanh",
            "linear",
        ],
        default="gelu",
        help="activation function to use",
    )

    parser.add_argument(
        "--layer-type",
        type=str,
        choices=["transformer", "conformer", "trf_adp"],
        default="transformer",
        help="layer type in encoder",
    )

    # dropouts
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.1,
        help="dropout probability for the transformer",
    )

    parser.add_argument(
        "--attention-dropout",
        type=float,
        default=0.1,
        help="dropout probability for attention weights",
    )

    parser.add_argument(
        "--activation-dropout",
        type=float,
        default=0.0,
        help="dropout probability after activation in FFN",
    )

    parser.add_argument(
        "--encoder-layerdrop",
        type=float,
        default=0.0,
        help="probability of dropping a transformer layer",
    )

    parser.add_argument(
        "--dropout-input",
        type=float,
        default=0.0,
        help="dropout to apply to the input (after feat extr)",
    )

    parser.add_argument(
        "--dropout-features",
        type=float,
        default=0.0,
        help="dropout to apply to the features (after feat extr)",
    )

    parser.add_argument(
        "--final-dim",
        type=int,
        default=0,
        help="project final representations and targets to this many dimensions. set to encoder_embed_dim if <= 0",
    )

    parser.add_argument(
        "--untie-final-proj",
        type=_str2bool,
        default=False,
        help="use separate projection for each target",
    )

    parser.add_argument(
        "--layer-norm-first",
        type=_str2bool,
        default=False,
        help="apply layernorm first in the transformer",
    )

    parser.add_argument(
        "--conv-feature-layers",
        type=str,
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
    )

    parser.add_argument(
        "--conv-bias",
        type=_str2bool,
        default=False,
        help="include bias in conv encoder",
    )

    parser.add_argument(
        "--logit-temp",
        type=float,
        default=0.1,
        help="temperature to divide logits by",
    )

    parser.add_argument(
        "--target-glu",
        type=_str2bool,
        default=False,
        help="adds projection + glu to targets",
    )

    parser.add_argument(
        "--feature-grad-mult",
        type=float,
        default=1.0,
        help="multiply feature extractor var grads by this",
    )

    # masking
    parser.add_argument("--mask-length", type=int, default=10, help="mask_length")

    parser.add_argument(
        "--mask-prob",
        type=float,
        default=0.65,
        help="probability of replacing a token with mask",
    )

    parser.add_argument(
        "--mask-selection",
        type=str,
        choices=["static", "uniform", "normal", "poisson"],
        default="static",
        help="how to choose mask length",
    )

    parser.add_argument(
        "--mask-other",
        type=float,
        default=0,
        help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
    )

    parser.add_argument(
        "--no-mask-overlap",
        type=_str2bool,
        default=False,
        help="if true, prevent mask spans from overlapping",
    )

    parser.add_argument(
        "--mask-min-space",
        type=int,
        default=1,
        help="min space between spans (if no overlap is enabled)",
    )

    # channel masking
    parser.add_argument(
        "--mask-channel-length",
        type=int,
        default=10,
        help="length of the mask for features (channels)",
    )

    parser.add_argument(
        "--mask-channel-prob",
        type=float,
        default=0.0,
        help="probability of replacing a feature with 0",
    )

    parser.add_argument(
        "--mask-channel-selection",
        type=str,
        choices=["static", "uniform", "normal", "poisson"],
        default="static",
        help="how to choose mask length for channel masking",
    )

    parser.add_argument(
        "--mask-channel-other",
        type=float,
        default=0,
        help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
    )

    parser.add_argument(
        "--no-mask-channel-overlap",
        type=_str2bool,
        default=False,
        help="if true, prevent channel mask spans from overlapping",
    )

    parser.add_argument(
        "--mask-channel-min-space",
        type=int,
        default=1,
        help="min space between spans (if no overlap is enabled)",
    )

    # positional embeddings
    parser.add_argument(
        "--conv-pos",
        type=int,
        default=128,
        help="number of filters for convolutional positional embeddings",
    )

    parser.add_argument(
        "--conv-pos-groups",
        type=int,
        default=16,
        help="number of groups for convolutional positional embedding",
    )

    parser.add_argument(
        "--conv-pos-batch-norm",
        type=_str2bool,
        default=False,
        help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
    )

    parser.add_argument(
        "--latent-temp",
        type=float,
        nargs="*",
        default=[2, 0.5, 0.999995],
        help="legacy (to be removed)",
    )

    # loss computation
    parser.add_argument(
        "--skip-masked",
        type=_str2bool,
        default=False,
        help="skip computing losses over masked frames",
    )

    parser.add_argument(
        "--skip-nomask",
        type=_str2bool,
        default=False,
        help="skip computing losses over unmasked frames",
    )

    parser.add_argument(
        "--checkpoint-activations",
        type=_str2bool,
        default=False,
        help="recompute activations and save memory for extra compute",
    )

    parser.add_argument(
        "--pred-masked-weight",
        type=float,
        default=1,
        help="weight for masked part in ssl loss",
    )

    parser.add_argument(
        "--pred-nomask-weight",
        type=float,
        default=0,
        help="weight for unmasked part in ssl loss",
    )

    parser.add_argument(
        "--loss-weights",
        type=float,
        nargs="*",
        default=[10],
        help="weights for the extra losses (e.g. features_pen) in ssl loss",
    )

    # FP16 optimization
    parser.add_argument(
        "--required-seq-len-multiple",
        type=int,
        default=2,
        help="pad the input to encoder such that the sequence length is divisible by multiple",
    )

    parser.add_argument(
        "--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
    )

    parser.add_argument(
        "--pos-enc-type",
        type=str,
        default="abs",
        help="Positional encoding type to use in conformer",
    )

    parser.add_argument(
        "--num-classes",
        type=int,
        nargs="*",
        default=[504],
        help="""num class, a little larger than the number of cluster,
        the largest is for padding,
        and the value should be the multiple of 4, for faster computation""",
    )
|
|
|
|
|
|
class HubertModel(nn.Module):
    """HuBERT-style masked-prediction model.

    A convolutional feature extractor turns ``source`` into frames. The frames are
    layer-normalised, optionally projected to ``encoder_embed_dim``, span-masked,
    and encoded by ``TransformerEncoder``. ``final_proj`` then scores
    ``sum(num_classes)`` classes per frame. ``forward`` returns the cross-entropy
    loss (plus a weighted feature penalty), or encoder outputs when
    ``features_only=True``.

    Args:
        cfg: namespace providing the attributes read in ``__init__`` (their names
            match the options defined in ``add_hubert_arguments``).
    """

    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        # NOTE(review): `eval` executes arbitrary code from the config string;
        # only use trusted configs. Expected form: "[(dim, kernel, stride), ...]".
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        # product of all conv strides = input samples per feature frame
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # label frames per feature frame
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        # learned vector written into masked time steps
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.untie_final_proj = cfg.untie_final_proj
        self.final_proj = nn.Linear(cfg.encoder_embed_dim, sum(cfg.num_classes))

        # modules below are not needed during fine-tuning
        self.num_classes = cfg.num_classes
        self.pred_masked_weight = cfg.pred_masked_weight
        self.pred_nomask_weight = cfg.pred_nomask_weight
        self.loss_weights = cfg.loss_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        ``nn.Module`` has no ``upgrade_state_dict_named``; a parent hook is only
        called when some class later in the MRO provides one.
        """
        parent_upgrade = getattr(super(), "upgrade_state_dict_named", None)
        if parent_upgrade is not None:
            parent_upgrade(state_dict, name)
        return state_dict

    def apply_mask(self, x, padding_mask, target_list):
        """Mask time steps (with ``mask_emb``) and channels (with 0) of ``x``.

        ``x`` is modified in place and has shape (B, T, C). ``target_list`` is
        unused. Returns ``(x, mask_indices)``, where ``mask_indices`` is a (B, T)
        bool tensor, or None when ``mask_prob <= 0``.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb.to(x.dtype)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # same channel mask for every time step of a sample
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        """Run the conv feature extractor; gradients are scaled by ``feature_grad_mult``.

        With ``feature_grad_mult <= 0`` the extractor runs under ``no_grad``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Trim ``features`` (time on dim 2) so labels exist, and pick one label per frame."""
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample a (B, S) padding mask to (B, T) frames of ``features`` (B, T, C).

        Trailing entries beyond a multiple of T are dropped; a frame is padding
        only if all of its entries are padding.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ):
        """Encode ``source`` and return either features or the training loss.

        output layer is 1-based.

        Returns:
            if ``features_only``: dict with "x", "padding_mask", "features";
            otherwise the ``(loss, sample_size, logging_output)`` tuple from
            ``compute_loss``.

        Raises:
            ValueError: if ``target_list`` is None and ``features_only`` is False.
        """
        if not features_only and target_list is None:
            raise ValueError("target_list is required unless features_only=True")

        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        # NOTE(review): computed (and passed through dropout) but not used below.
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        # A missing padding mask means every frame is real; missing mask
        # indices (mask=False or mask_prob == 0) mean nothing was masked.
        frame_shape = features.shape[:2]
        if padding_mask is None:
            valid = torch.ones(frame_shape, dtype=torch.bool, device=x.device)
        else:
            valid = ~padding_mask
        if mask_indices is None:
            mask_indices = torch.zeros(frame_shape, dtype=torch.bool, device=x.device)
        masked_indices = torch.logical_and(valid, mask_indices)
        nomask_indices = torch.logical_and(valid, ~mask_indices)

        num_targets = len(target_list)
        if not self.skip_masked:
            proj_x_m = self.final_proj(x[masked_indices])
            proj_x_m /= self.logit_temp
            logit_m_list = [proj_x_m for _ in range(num_targets)]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            proj_x_u = self.final_proj(x[nomask_indices])
            proj_x_u /= self.logit_temp
            logit_u_list = [proj_x_u for _ in range(num_targets)]
        else:
            logit_u_list = [None for _ in target_list]

        # only the first target stream is used for every prediction head
        targ_m = target_list[0][masked_indices].long()
        targ_m_list = [targ_m for _ in range(num_targets)]

        targ_u = target_list[0][nomask_indices].long()
        targ_u_list = [targ_u for _ in range(num_targets)]

        return self.compute_loss(
            logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
        )

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (encoder output or, with ``ret_conv``, pre-encoder features, padding mask)."""
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        """Return the non-None masked (or unmasked) logits of ``net_output`` as float."""
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        """Return an all-zero long target vector per logits tensor."""
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        """Return ``([features_pen], ["features_pen"])`` if present in ``net_output``."""
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        """Drop the prediction head."""
        self.final_proj = None

    def compute_loss(
        self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
    ):
        """Sum-reduced cross entropy on masked/unmasked frames plus weighted extras.

        ``None`` entries in a logits list (skip_masked / skip_nomask) skip that
        part: its loss is neither computed nor logged, and its correct/count are 0.

        Returns:
            ``(loss, sample_size, logging_output)``.
        """
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduce = True
        reduction = "sum" if reduce else "none"

        def _part_loss(logit_list, targ_list, tag):
            # concatenated logits, targets and CE loss, or Nones if skipped
            logits = [x.float() for x in logit_list if x is not None]
            if not logits:
                return None, None, None
            logp = torch.cat(logits)
            targ = torch.cat(targ_list)
            part_loss = F.cross_entropy(logp, targ, reduction=reduction)
            logging_output[f"loss_{tag}_0"] = part_loss.detach().item()
            return logp, targ, part_loss

        logp_m, targ_m, loss_m = _part_loss(logit_m_list, targ_m_list, "m")
        assert self.pred_masked_weight == 0 or (
            logp_m is not None and len(logp_m) > 0
        )
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * loss_m
            sample_size += len(targ_m)

        logp_u, targ_u, loss_u = _part_loss(logit_u_list, targ_u_list, "u")
        assert self.pred_nomask_weight == 0 or (
            logp_u is not None and len(logp_u) > 0
        )
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * loss_u
            sample_size += len(targ_u)

        if self.loss_weights is not None:
            extra_losses = [features_pen]
            names = ["features_pen"]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    # extra losses are scaled by sample_size to match the summed CE
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            # float() works whether loss is still the 0.0 float or a tensor
            "loss": float(loss) if reduce else loss,
            **logging_output,
        }

        def compute_correct(logits, target):
            if logits is None or logits.numel() == 0:
                return 0, 0
            assert logits.dim() > 1, logits.shape
            is_max = logits.argmax(-1) == target
            is_min = logits.argmin(-1) == target
            # rows where argmax == argmin (e.g. constant logits) are not counted
            both = is_max & is_min
            corr = is_max.long().sum().item() - both.long().sum().item()
            count = is_max.numel()
            return corr, count

        with torch.no_grad():
            corr_m, count_m = compute_correct(logp_m, targ_m)
            logging_output["correct_m_0"] = corr_m
            logging_output["count_m_0"] = count_m

            corr_u, count_u = compute_correct(logp_u, targ_u)
            logging_output["correct_u_0"] = corr_u
            logging_output["count_u_0"] = count_u

        return loss, sample_size, logging_output
|