Yifan Yang 87843e9382
k2SSL: a Faster and Better Framework for Self-Supervised Speech Representation Learning (#1500)
* Add k2SSL

* fix flake8

* fix for black

* fix for black

* fix for black

* Update ssl_datamodule.py

* Fix bugs in HubertDataset

* update comments

* add librilight

* add checkpoint convert script

* format

---------

Co-authored-by: yifanyeung <yifanyeung@yifanyeung.local>
Co-authored-by: zzasdf <15218404468@163.com>
2024-04-04 23:29:16 +08:00

941 lines
30 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
def _no_overlap_span_indices(
    rng: np.random.Generator,
    sz: int,
    lengths,
    min_space: int,
) -> np.ndarray:
    """Place spans of the given ``lengths`` inside ``[0, sz)`` without overlap.

    Spans are placed longest first. For each span, one of the currently free
    segments is chosen with probability proportional to its length (segments
    too short for ``length + min_space`` get probability 0). A start position
    is then drawn inside that segment, and the segment is split around the
    new span; leftover pieces that are too short are dropped. Placement
    stops early once no free segment can hold the next span.

    Args:
        rng: generator used for every random draw.
        sz: number of usable positions.
        lengths: span lengths to place.
        min_space: gap to keep between neighbouring spans.

    Returns:
        A 1-D int64 array of masked positions (unsorted, may be empty).
    """
    covered = []
    free_parts = [(0, sz)]
    keep_length = min(lengths)

    def arrange(s, e, length):
        # np.random.Generator has no ``randint``; ``integers`` is its
        # equivalent (high bound exclusive).
        span_start = rng.integers(s, e - length)
        covered.extend(span_start + k for k in range(length))
        new_parts = []
        if span_start - s - min_space >= keep_length:
            new_parts.append((s, span_start - min_space + 1))
        if e - span_start - length - min_space > keep_length:
            new_parts.append((span_start + length + min_space, e))
        return new_parts

    for length in sorted(lengths, reverse=True):
        # ``np.int`` was removed in NumPy 1.24; use an explicit dtype.
        lens = np.fromiter(
            (e - s if e - s >= length + min_space else 0 for s, e in free_parts),
            np.int64,
        )
        l_sum = np.sum(lens)
        if l_sum == 0:
            break
        probs = lens / l_sum
        c = rng.choice(len(free_parts), p=probs)
        s, e = free_parts.pop(c)
        free_parts.extend(arrange(s, e, length))

    # Explicit dtype so an empty result is still a valid integer index array.
    return np.asarray(covered, dtype=np.int64)


def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks. Should be of size 2,
            where the first element is the batch size and the second is the
            number of timesteps.
        padding_mask: optional mask of the same size as shape (True = padded).
            For row i, spans are drawn from the first ``sz`` positions, where
            sz is the row length minus the number of padded entries, so
            padding is assumed to be at the end of each row.
        mask_prob: the number of spans per row is
            ``mask_prob * sz / mask_length``, rounded probabilistically, so
            roughly this fraction of elements is masked (less when spans
            overlap).
        mask_length: span length (or its mean / lambda, see mask_type).
        mask_type: how to compute span lengths
            static = fixed size ``mask_length``
            uniform = integers sampled from [mask_other, mask_length*2]
            normal = sampled from a normal distribution with mean mask_length
                and stdev mask_other; each span is at least 1 element
            poisson = sampled from a poisson distribution with
                lambda = mask_length
        mask_other: secondary parameter for "uniform" / "normal".
        min_masks: minimum number of spans per row.
        no_overlap: if True, use an alternative algorithm that places spans
            so that they do not overlap.
        min_space: only used if no_overlap is True; the number of unmasked
            elements to keep between spans.
        require_same_masks: if True, every row ends with the same number of
            masked elements: rows are randomly subsampled down to the
            smallest count (or, with add_masks, topped up with random extra
            positions to the largest count).
        mask_dropout: fraction of each row's masked elements to unmask again
            at the end.
        add_masks: see require_same_masks.
        seed, epoch, indices: if all are given, the generator of row i is
            seeded from ``(seed, epoch, indices[i])`` for reproducible masks.
        idc_select_ver: 1 draws span starts from ``[0, sz - min(lengths))``;
            2 draws them from ``[0, sz)``. Spans are truncated at sz.
        num_mask_ver: 1 computes the span count with the global
            ``np.random``; 2 uses the per-row generator.

    Returns:
        A boolean array of shape ``shape``; True marks masked elements.

    Raises:
        ValueError: for an unsupported num_mask_ver / idc_select_ver, for a
            "static" mask whose positive span count has total length 0, or
            when a whole row would be masked.
        Exception: for an unknown mask_type.
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError(f"unsupported num_mask_ver: {num_mask_ver}")

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # Generator.integers is the Generator API for RandomState.randint.
            lengths = rng.integers(
                int(mask_other), mask_length * 2 + 1, size=num_mask
            )
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if num_mask == 0:
                # Zero spans were requested for this row (e.g. a small
                # mask_prob rounded down with min_masks=0): nothing to mask.
                mask_idcs.append(np.zeros(0, dtype=np.int64))
                continue
            if mask_type == "static":
                raise ValueError("this should never happen")
            # Every sampled span has length 0: fall back to a single span and
            # keep num_mask in sync with len(lengths).
            lengths = [min(mask_length, sz - 1)]
            num_mask = 1

        if no_overlap:
            mask_idc = _no_overlap_span_indices(rng, sz, lengths, min_space)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError(f"unsupported idc_select_ver: {idc_select_ver}")

            # Expand each span start into every position the span covers.
            mask_idc = np.asarray(
                [
                    start + offset
                    for start, length in zip(mask_idc, lengths)
                    for offset in range(length)
                ],
                dtype=np.int64,
            )

        # Drop positions past the usable length and de-duplicate overlaps.
        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        counts = [len(m) for m in mask_idcs]
        target_len = max(counts) if add_masks else min(counts)

    # NOTE: ``rng`` below is the generator of the last row of the loop above.
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
def add_hubert_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--label-rate",
type=float,
default=50,
)
parser.add_argument(
"--sample-rate",
type=float,
default=16000,
)
parser.add_argument(
"--extractor-mode",
type=str,
default="default",
help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. default has a single group
norm with d groups in the first conv block, whereas layer_norm
has layer norms in every block (meant to use with normalize=True)""",
)
parser.add_argument(
"--encoder-layers",
type=int,
default=12,
help="num encoder layers in the transformer",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
default=768,
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
default=3072,
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
default=12,
help="num encoder attention heads",
)
parser.add_argument(
"--activation-fn",
type=str,
choices=[
"relu",
"gelu",
"gelu_fast",
"gelu_accurate",
"tanh",
"linear",
],
default="gelu",
help="activation function to use",
)
parser.add_argument(
"--layer-type",
type=str,
choices=["transformer", "conformer", "trf_adp"],
default="transformer",
help="layer type in encoder",
)
# dropouts
parser.add_argument(
"--dropout",
type=float,
default=0.1,
help="dropout probability for the transformer",
)
parser.add_argument(
"--attention-dropout",
type=float,
default=0.1,
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
type=float,
default=0.0,
help="dropout probability after activation in FFN",
)
parser.add_argument(
"--encoder-layerdrop",
type=float,
default=0.0,
help="probability of dropping a tarnsformer layer",
)
parser.add_argument(
"--dropout-input",
type=float,
default=0.0,
help="dropout to apply to the input (after feat extr)",
)
parser.add_argument(
"--dropout-features",
type=float,
default=0.0,
help="dropout to apply to the features (after feat extr)",
)
parser.add_argument(
"--final-dim",
type=int,
default=0,
help="project final representations and targets to this many dimensions. set to encoder_embed_dim is <= 0",
)
parser.add_argument(
"--untie-final-proj",
type=bool,
default=False,
help="use separate projection for each target",
)
parser.add_argument(
"--layer-norm-first",
type=bool,
default=False,
help="apply layernorm first in the transformer",
)
parser.add_argument(
"--conv-feature-layers",
type=str,
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
)
parser.add_argument(
"--conv-bias",
type=bool,
default=False,
help="include bias in conv encoder",
)
parser.add_argument(
"--logit-temp",
type=float,
default=0.1,
help="temperature to divide logits by",
)
parser.add_argument(
"--target-glu",
type=bool,
default=False,
help="adds projection + glu to targets",
)
parser.add_argument(
"--feature-grad-mult",
type=float,
default=1.0,
help="multiply feature extractor var grads by this",
)
# masking
parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
parser.add_argument(
"--mask-prob",
type=float,
default=0.65,
help="probability of replacing a token with mask",
)
parser.add_argument(
"--mask-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
default="static",
help="how to choose mask length",
)
parser.add_argument(
"--mask-other",
type=float,
default=0,
help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh",
)
parser.add_argument(
"--no-mask-overlap",
type=bool,
default=False,
help="whether to allow masks to overlap",
)
parser.add_argument(
"--mask-min-space",
type=int,
default=1,
help="min space between spans (if no overlap is enabled)",
)
# channel masking
parser.add_argument(
"--mask-channel-length",
type=int,
default=10,
help="length of the mask for features (channels)",
)
parser.add_argument(
"--mask-channel-prob",
type=float,
default=0.0,
help="probability of replacing a feature with 0",
)
parser.add_argument(
"--mask-channel-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
default="static",
help="how to choose mask length for channel masking",
)
parser.add_argument(
"--mask-channel-other",
type=float,
default=0,
help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh",
)
parser.add_argument(
"--no-mask-channel-overlap",
type=bool,
default=False,
help="whether to allow channel masks to overlap",
)
parser.add_argument(
"--mask-channel-min-space",
type=int,
default=1,
help="min space between spans (if no overlap is enabled)",
)
# positional embeddings
parser.add_argument(
"--conv-pos",
type=int,
default=128,
help="number of filters for convolutional positional embeddings",
)
parser.add_argument(
"--conv-pos-groups",
type=int,
default=16,
help="number of groups for convolutional positional embedding",
)
parser.add_argument(
"--conv-pos-batch-norm",
type=bool,
default=False,
help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
)
parser.add_argument(
"--latent-temp",
type=float,
nargs="*",
default=[2, 0.5, 0.999995],
help="legacy (to be removed)",
)
# loss computation
parser.add_argument(
"--skip-masked",
type=bool,
default=False,
help="skip computing losses over masked frames",
)
parser.add_argument(
"--skip-nomask",
type=bool,
default=False,
help="skip computing losses over unmasked frames",
)
parser.add_argument(
"--checkpoint-activations",
type=bool,
default=False,
help="recompute activations and save memory for extra compute",
)
parser.add_argument(
"--pred-masked-weight",
type=float,
default=1,
help="weight for masked part in ssl loss",
)
parser.add_argument(
"--pred-nomask-weight",
type=float,
default=0,
help="weight for masked part in ssl loss",
)
parser.add_argument(
"--loss-weights",
type=float,
nargs="*",
default=[10],
help="weight for masked part in ssl loss",
)
# FP16 optimization
parser.add_argument(
"--required-seq-len-multiple",
type=int,
default=2,
help="pad the input to encoder such that the sequence length is divisible by multiple",
)
parser.add_argument(
"--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
)
parser.add_argument(
"--pos-enc-type",
type=str,
default="abs",
help="Positional encoding type to use in conformer",
)
parser.add_argument(
"--num-classes",
type=int,
nargs="*",
default=[504],
help="""num class, a little larger than the number of cluster,
the largest is for padding,
and the value should be the multiple of 4, for faster computation""",
)
class HubertModel(nn.Module):
    """HuBERT-style model trained with a cross-entropy loss over class ids.

    Raw input goes through a convolutional feature extractor. Features are
    layer-normed, optionally projected to ``encoder_embed_dim``, partially
    replaced by a learned mask embedding, and encoded by a Transformer.
    ``final_proj`` maps every encoder frame to ``sum(cfg.num_classes)``
    logits, which ``compute_loss`` scores against frame-aligned targets.

    ``cfg`` is presumably the argparse namespace built with
    ``add_hubert_arguments`` (the attributes read here match those options);
    it is also passed as-is to ``TransformerEncoder``.
    """

    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        # NOTE(review): eval() of a config string - only use trusted configs.
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        # Product of the conv strides: input samples per feature frame.
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # Number of target labels per feature frame.
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        # Time masking.
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # Channel masking.
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        # Learned vector written into masked frames by apply_mask().
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        # A single projection is shared by all targets; untie_final_proj is
        # stored but does not create per-target projections here.
        self.untie_final_proj = cfg.untie_final_proj
        # Not needed during fine-tuning (see remove_pretraining_modules).
        self.final_proj = nn.Linear(cfg.encoder_embed_dim, sum(cfg.num_classes))

        self.num_classes = cfg.num_classes
        self.pred_masked_weight = cfg.pred_masked_weight
        self.pred_nomask_weight = cfg.pred_nomask_weight
        self.loss_weights = cfg.loss_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Return ``state_dict`` unchanged.

        Kept for compatibility with the fairseq model API ("upgrade a
        possibly old state dict"). ``nn.Module`` defines no such hook, so
        there is no parent implementation to delegate to.
        """
        return state_dict

    def apply_mask(self, x, padding_mask, target_list):
        """Mask random time spans with ``mask_emb`` and zero random channels.

        Args:
          x: (B, T, C) frame features; modified in place.
          padding_mask: optional (B, T) bool mask (True = padding), forwarded
            to compute_mask_indices so padded frames are not masked.
          target_list: unused.

        Returns:
          ``(x, mask_indices)``. ``mask_indices`` is a (B, T) bool tensor,
          True at masked frames, or None when ``mask_prob`` is 0.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb.to(x.dtype)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the (B, C) channel mask over all T frames.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        """Run the conv feature extractor on ``source``.

        The extractor's gradient is multiplied by ``feature_grad_mult``
        (via GradMultiply) when that is not 1.0. When it is <= 0, the
        extractor runs under ``torch.no_grad()``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Trim features so every frame has a label, then align the labels.

        Args:
          features: (B, C, T) extractor output (time is the last dim).
          target_list: label tensors indexed as (B, T_label).

        Returns:
          The possibly time-trimmed features, and the targets resampled to
          one label per remaining frame
          (frame t -> label ``int(t * feat2tar_ratio)``).
        """
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample an input-level padding mask to frame level.

        Trailing input positions that do not fill a whole frame are dropped.
        The remaining positions are split evenly across ``features.size(1)``
        frames, and a frame counts as padding only if all of its positions
        are padding.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ):
        """Encode ``source`` and, unless ``features_only``, compute the loss.

        Args:
          source: input passed to the conv feature extractor.
          target_list: frame-level label tensors; required unless
            ``features_only`` is True.
          padding_mask: optional input-level bool mask, True at padding.
          mask: whether to apply time/channel masking.
          features_only: if True, return
            ``{"x", "padding_mask", "features"}`` without computing a loss.
          output_layer: 1-based encoder layer to stop at; None is passed to
            the encoder as ``layer=None``.

        Returns:
          The dict above when ``features_only``; otherwise the
          ``(loss, sample_size, logging_output)`` tuple from compute_loss.

        Raises:
          ValueError: if ``target_list`` is None and ``features_only`` is
            False.
        """
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        # NOTE(review): unmasked_features is not consumed below.
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        if target_list is None:
            raise ValueError("target_list is required unless features_only=True")

        # Use all-False masks when no padding / no time masking was applied,
        # so the frame selections below are always defined.
        batch_size, num_frames = features.shape[:2]
        if padding_mask is None:
            padding_mask = torch.zeros(
                batch_size, num_frames, dtype=torch.bool, device=x.device
            )
        if mask_indices is None:
            mask_indices = torch.zeros_like(padding_mask)

        masked_indices = torch.logical_and(~padding_mask, mask_indices)
        nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)

        if not self.skip_masked:
            proj_x_m = self.final_proj(x[masked_indices]) / self.logit_temp
            logit_m_list = [proj_x_m for _ in target_list]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            proj_x_u = self.final_proj(x[nomask_indices]) / self.logit_temp
            logit_u_list = [proj_x_u for _ in target_list]
        else:
            logit_u_list = [None for _ in target_list]

        # Only the first target is used; it is repeated once per target.
        targ_m = target_list[0][masked_indices].long()
        targ_m_list = [targ_m for _ in target_list]
        targ_u = target_list[0][nomask_indices].long()
        targ_u_list = [targ_u for _ in target_list]

        return self.compute_loss(
            logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
        )

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return encoder output (or pre-encoder features if ``ret_conv``).

        Returns:
          ``(features, padding_mask)``; ``padding_mask`` is frame-level, or
          None if none was given.
        """
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        """Return the float logits stored under "logit_m_list"/"logit_u_list".

        NOTE(review): forward() here returns a loss tuple, not a dict with
        these keys; this looks like a fairseq leftover.
        """
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        """Return all-zero long targets, one per row of each logits tensor."""
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        """Return ``([features_pen], ["features_pen"])`` if present, else empty lists."""
        extra_losses = []
        names = []
        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")
        return extra_losses, names

    def remove_pretraining_modules(self):
        """Drop ``final_proj``, which is only used for the pre-training loss."""
        self.final_proj = None

    @staticmethod
    def _cat_logits_and_targets(logit_list, targ_list):
        """Concatenate the non-None logits (as float) and all targets.

        Returns ``(None, None)`` when every logits entry is None (that part
        of the loss was skipped).
        """
        logits = [x.float() for x in logit_list if x is not None]
        if not logits:
            return None, None
        return torch.cat(logits), torch.cat(targ_list)

    @staticmethod
    def _compute_correct(logits: torch.Tensor, target: torch.Tensor) -> Tuple[int, int]:
        """Count rows whose argmax equals the target, and the number of rows.

        Rows where argmax and argmin are the same index (e.g. all logits
        equal) are not counted as correct.
        """
        if logits.numel() == 0:
            return 0, 0
        assert logits.dim() > 1, logits.shape
        is_max = logits.argmax(-1) == target
        is_min = logits.argmin(-1) == target
        both = is_max & is_min
        corr = is_max.long().sum().item() - both.long().sum().item()
        count = is_max.numel()
        return corr, count

    def compute_loss(
        self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
    ):
        """Weighted cross-entropy over masked/unmasked frames plus a penalty.

        Args:
          logit_m_list, logit_u_list: per-target logits for masked / unmasked
            frames; entries are None when that part was skipped.
          targ_m_list, targ_u_list: matching lists of class-id tensors.
          features_pen: feature penalty; it is added as
            ``loss_weights[0] * features_pen * sample_size``.

        Returns:
          ``(loss, sample_size, logging_output)``. ``loss`` is the summed
          (``reduction="sum"``) weighted loss. ``sample_size`` counts the
          frames of the CE terms with a positive weight. ``logging_output``
          holds Python numbers for logging.
        """
        loss = 0.0
        sample_size = 0
        logging_output = {}

        logp_m, targ_m = self._cat_logits_and_targets(logit_m_list, targ_m_list)
        has_m = logp_m is not None
        if has_m:
            loss_m = F.cross_entropy(logp_m, targ_m, reduction="sum")
            logging_output["loss_m_0"] = loss_m.detach().item()
        assert self.pred_masked_weight == 0 or (has_m and len(logp_m) > 0)
        if self.pred_masked_weight > 0:
            loss = loss + self.pred_masked_weight * loss_m
            sample_size += len(targ_m)

        logp_u, targ_u = self._cat_logits_and_targets(logit_u_list, targ_u_list)
        has_u = logp_u is not None
        if has_u:
            loss_u = F.cross_entropy(logp_u, targ_u, reduction="sum")
            logging_output["loss_u_0"] = loss_u.detach().item()
        assert self.pred_nomask_weight == 0 or (has_u and len(logp_u) > 0)
        if self.pred_nomask_weight > 0:
            loss = loss + self.pred_nomask_weight * loss_u
            sample_size += len(targ_u)

        if self.loss_weights is not None:
            extra_losses = [features_pen]
            names = ["features_pen"]
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss = loss + p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if torch.is_tensor(loss) else float(loss),
            **logging_output,
        }

        with torch.no_grad():
            for tag, logp, targ in (("m", logp_m, targ_m), ("u", logp_u, targ_u)):
                if logp is None:
                    corr, count = 0, 0
                else:
                    corr, count = self._compute_correct(logp, targ)
                logging_output[f"correct_{tag}_0"] = corr
                logging_output[f"count_{tag}_0"] = count

        return loss, sample_size, logging_output