mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

deleted k2SSL files

This commit is contained in:
parent b793c5e958
commit 04a81f53f0
@@ -1,42 +0,0 @@
# Pretraining parameter settings
train_params = {
    # Model parameters
    "label_rate": 50,
    "sample_rate": 16000,
    "extractor_mode": "default",
    "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
    "conv_bias": False,
    "feature_grad_mult": 1.0,

    # Masking parameters
    "mask_length": 10,
    "mask_prob": 0.65,
    "mask_selection": "static",
    "mask_other": 0,
    "no_mask_overlap": False,
    "mask_min_space": 1,

    # Channel-masking parameters
    "mask_channel_length": 10,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_channel_other": 0,
    "no_mask_channel_overlap": False,
    "mask_channel_min_space": 1,

    # Loss computation parameters
    "skip_masked": False,
    "skip_nomask": False,
    "pred_masked_weight": 1,
    "pred_nomask_weight": 0,
    "loss_weights": [10],
    "checkpoint_activations": False,

    # Other parameters
    "dropout_input": 0.0,
    "dropout_features": 0.0,
    "num_classes": [504],
    "untie_final_proj": False,
    "required_seq_len_multiple": 2,
    "logit_temp": 0.1,
}
@@ -1,17 +0,0 @@
def get_zipformer_base_config():
    return {
        "output_downsampling_factor": 1,
        "downsampling_factor": (1, 2, 4, 8, 4, 2),
        "encoder_dim": (192, 256, 384, 512, 384, 256),
        "num_encoder_layers": (2, 2, 3, 4, 3, 2),
        "encoder_unmasked_dim": (192, 192, 256, 256, 256, 192),
        "query_head_dim": 32,
        "pos_head_dim": 4,
        "value_head_dim": 12,
        "pos_dim": 48,
        "num_heads": (4, 4, 4, 8, 4, 4),
        "feedforward_dim": (512, 768, 1024, 1536, 1024, 768),
        "cnn_module_kernel": (31, 31, 15, 15, 15, 31),
        "dropout": 0.1,
        "warmup_batches": 4000.0,
    }
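The diff does not show how these two blocks were combined. The sketch below is a minimal, hypothetical way to merge them into the attribute-style config that the HubertModel defined in the next file expects; make_pretrain_config is not part of the deleted files. Note that HubertModel parses the Zipformer shape options with _to_int_tuple, which expects comma-separated strings, so tuples and ints are rendered as strings here (an assumption about how the dicts are consumed).

from types import SimpleNamespace

def make_pretrain_config():
    """Hypothetical helper: merge the two dicts above into one config object."""
    cfg = {**get_zipformer_base_config(), **train_params}
    # HubertModel reads the Zipformer shape options via _to_int_tuple, which
    # splits on ",", so tuples/ints are converted to comma-separated strings.
    for key in (
        "downsampling_factor", "encoder_dim", "num_encoder_layers",
        "encoder_unmasked_dim", "query_head_dim", "pos_head_dim",
        "value_head_dim", "num_heads", "feedforward_dim", "cnn_module_kernel",
    ):
        value = cfg[key]
        cfg[key] = (
            ",".join(str(v) for v in value) if isinstance(value, tuple) else str(value)
        )
    return SimpleNamespace(**cfg)

# model = HubertModel(make_pretrain_config())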
@@ -1,601 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import argparse
import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScheduledFloat
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel
from zipformer import Zipformer2

def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks. Should be of size 2,
            where the first element is the batch size and the second is the
            number of timesteps.
        padding_mask: optional padding mask of the same size as shape, which
            prevents masking of padded elements.
        mask_prob: probability for each token to be chosen as the start of a
            masked span. This is multiplied by the number of timesteps divided
            by the mask-span length, so that approximately this fraction of all
            elements is masked. Due to overlaps, the actual number will be
            smaller (unless no_overlap is True).
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and
                stdev mask_other; each mask is at least 1 element long
            poisson = sample from Poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if True, uses an alternative recursive algorithm that
            prevents spans from overlapping
        min_space: only used if no_overlap is True; how many elements to keep
            unmasked between spans
        require_same_masks: if True, randomly drop masks until the same number
            of masked elements remains in each sample
        mask_dropout: randomly drop this fraction of masked elements in each
            example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # np.random.Generator has no randint(); integers() is the
            # equivalent method.
            lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happen")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask

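# Illustrative usage (not part of the original file): mask a batch of
# 2 utterances of 100 frames each, using the pretraining defaults above.
#
#     mask = compute_mask_indices(
#         shape=(2, 100),
#         padding_mask=None,
#         mask_prob=0.65,
#         mask_length=10,
#     )
#     # mask is a boolean np.ndarray of shape (2, 100); because spans may
#     # overlap, typically somewhat less than 65% of each row is True.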
def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

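# Example (illustrative, not in the original file):
#   _to_int_tuple("192,256,384,512,384,256") -> (192, 256, 384, 512, 384, 256)
# The Zipformer shape options are passed to HubertModel as comma-separated
# strings and parsed with this helper.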
class HubertModel(nn.Module):
    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # With the default conv strides (5, 2, 2, 2, 2, 2, 2) this is 320, so
        # feat2tar_ratio = 50 * 320 / 16000 = 1.0: one label per encoder frame.
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
        encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
        encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
        self.post_extract_proj = (
            nn.Linear(self.embed, encoder_input_dim)
            if self.embed != encoder_input_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())

        self.encoder = Zipformer2(
            output_downsampling_factor=1,
            downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
            num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
            encoder_dim=_to_int_tuple(cfg.encoder_dim),
            encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
            query_head_dim=_to_int_tuple(cfg.query_head_dim),
            pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
            value_head_dim=_to_int_tuple(cfg.value_head_dim),
            pos_dim=cfg.pos_dim,
            num_heads=_to_int_tuple(cfg.num_heads),
            feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
            cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
            warmup_batches=4000.0,
        )

        self.layer_norm = LayerNorm(self.embed)

        self.untie_final_proj = cfg.untie_final_proj
        self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))

        # modules below are not needed during fine-tuning
        self.num_classes = cfg.num_classes
        self.pred_masked_weight = cfg.pred_masked_weight
        self.pred_nomask_weight = cfg.pred_nomask_weight
        self.loss_weights = cfg.loss_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb.to(x.dtype)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

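    # Note (added, not in the original file): apply_mask() replaces masked
    # frames with the learned self.mask_emb vector; min_masks=2 above ensures
    # at least two masked spans per utterance, even for short inputs.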
    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

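    # Worked example (added, not in the original file): with the default
    # conv_feature_layers the waveform is downsampled by 5*2*2*2*2*2*2 = 320,
    # so a sample-level padding_mask of shape (B, 48000) aligned to 150
    # encoder frames is viewed as (B, 150, 320) above, and a frame counts as
    # padding only if all 320 underlying samples are padding.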
    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ):
        """output layer is 1-based"""
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float -> (T, B, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x = x.transpose(0, 1)
        x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
        x = x.transpose(0, 1)

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            proj_x_m /= self.logit_temp
            logit_m_list = [proj_x_m for _ in range(len(target_list))]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            proj_x_u /= self.logit_temp
            logit_u_list = [proj_x_u for _ in range(len(target_list))]
        else:
            logit_u_list = [None for _ in target_list]

        # result = {
        #     "logit_m_list": logit_m_list,
        #     "logit_u_list": logit_u_list,
        #     "padding_mask": padding_mask,
        #     "features_pen": features_pen,
        # }
        targ_m_list = target_list[0][masked_indices]
        targ_m_list = targ_m_list.long()
        targ_m_list = [targ_m_list for _ in range(len(target_list))]

        targ_u_list = target_list[0][nomask_indices]
        targ_u_list = targ_u_list.long()
        targ_u_list = [targ_u_list for _ in range(len(target_list))]
        return self.compute_loss(
            logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
        )

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        self.final_proj = None

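    # Note (added, not in the original file): with the defaults above,
    # compute_loss() returns
    #   loss = pred_masked_weight * CE(masked frames)
    #          + pred_nomask_weight * CE(unmasked frames)   # weight 0 by default
    #          + loss_weights[0] * features_pen * sample_size,
    # where sample_size is the number of masked target frames.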
    def compute_loss(
        self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
    ):
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduce = True
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = [x.float() for x in logit_m_list if x is not None]
        logp_m_list = torch.cat(logp_m_list)
        targ_m_list = torch.cat(targ_m_list)

        loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
        loss_m_list.append(loss_m)
        logging_output["loss_m_0"] = loss_m.detach().item()

        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += len(targ_m_list)

        loss_u_list = []
        logp_u_list = [x.float() for x in logit_u_list if x is not None]
        logp_u_list = torch.cat(logp_u_list)
        targ_u_list = torch.cat(targ_u_list)

        loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
        loss_u_list.append(loss_u)
        logging_output["loss_u_0"] = loss_u.detach().item()

        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += len(targ_u_list)

        if self.loss_weights is not None:
            extra_losses = []
            names = []
            extra_losses.append(features_pen)
            names.append("features_pen")
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            **logging_output,
        }

        # for lk in self.log_keys:
        #     if lk in net_output:
        #         logging_output[lk] = float((net_output[lk]))

        def compute_correct(logits, target):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == target
                min = logits.argmin(-1) == target
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
            logging_output["correct_m_0"] = corr_m
            logging_output["count_m_0"] = count_m

            corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
            logging_output["correct_u_0"] = corr_u
            logging_output["count_u_0"] = count_u

        return loss, sample_size, logging_output
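A minimal, hypothetical pretraining step with the model above (shapes, the make_pretrain_config helper from the sketch earlier, and the loss normalization are illustrative; the actual training loop lives in the recipe's training script):

import torch

# Batch of 2 utterances, 3 s at 16 kHz; one k-means label per encoder frame
# (48000 samples / 320 = 150 frames, labels drawn from num_classes = [504]).
source = torch.randn(2, 48000)
padding_mask = torch.zeros(2, 48000, dtype=torch.bool)  # no padding
labels = torch.randint(0, 504, (2, 150))

model = HubertModel(make_pretrain_config())
loss, sample_size, logging_output = model(
    source, target_list=[labels], padding_mask=padding_mask, mask=True
)
(loss / sample_size).backward()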
File diff suppressed because it is too large