deleted k2SSL files

This commit is contained in:
RedSheep 2025-03-12 12:08:03 +08:00
parent b793c5e958
commit 04a81f53f0
4 changed files with 0 additions and 2040 deletions

View File

@ -1,42 +0,0 @@
# 預訓練參數設置
train_params = {
# 模型參數
"label_rate": 50,
"sample_rate": 16000,
"extractor_mode": "default",
"conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
"conv_bias": False,
"feature_grad_mult": 1.0,
# 掩碼參數
"mask_length": 10,
"mask_prob": 0.65,
"mask_selection": "static",
"mask_other": 0,
"no_mask_overlap": False,
"mask_min_space": 1,
# 通道掩碼參數
"mask_channel_length": 10,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_channel_other": 0,
"no_mask_channel_overlap": False,
"mask_channel_min_space": 1,
# 損失計算參數
"skip_masked": False,
"skip_nomask": False,
"pred_masked_weight": 1,
"pred_nomask_weight": 0,
"loss_weights": [10],
"checkpoint_activations": False,
# 其他參數
"dropout_input": 0.0,
"dropout_features": 0.0,
"num_classes": [504],
"untie_final_proj": False,
"required_seq_len_multiple": 2,
"logit_temp": 0.1,
}

View File

@ -1,17 +0,0 @@
def get_zipformer_base_config():
return {
"output_downsampling_factor": 1,
"downsampling_factor": (1, 2, 4, 8, 4, 2),
"encoder_dim": (192, 256, 384, 512, 384, 256),
"num_encoder_layers": (2, 2, 3, 4, 3, 2),
"encoder_unmasked_dim": (192, 192, 256, 256, 256, 192),
"query_head_dim": 32,
"pos_head_dim": 4,
"value_head_dim": 12,
"pos_dim": 48,
"num_heads": (4, 4, 4, 8, 4, 4),
"feedforward_dim": (512, 768, 1024, 1536, 1024, 768),
"cnn_module_kernel": (31, 31, 15, 15, 15, 31),
"dropout": 0.1,
"warmup_batches": 4000.0,
}

View File

@ -1,601 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScheduledFloat
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel
from zipformer import Zipformer2
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
add_masks: bool = False,
seed: Optional[int] = None,
epoch: Optional[int] = None,
indices: Optional[torch.Tensor] = None,
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
mask_dropout: randomly dropout this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
if num_mask_ver == 1:
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if seed is not None and epoch is not None and indices is not None:
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
else:
seed_i = None
rng = np.random.default_rng(seed_i)
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
assert sz >= 0, sz
else:
sz = all_sz
if num_mask_ver == 1:
if padding_mask is not None:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
num_mask = all_num_mask
elif num_mask_ver == 2:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ rng.random()
)
num_mask = max(min_masks, num_mask)
else:
raise ValueError()
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = rng.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = rng.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
if mask_type == "static":
raise ValueError(f"this should never happens")
else:
lengths = [min(mask_length, sz - 1)]
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = rng.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = rng.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
if idc_select_ver == 1:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
elif idc_select_ver == 2:
mask_idc = rng.choice(sz, num_mask, replace=False)
else:
raise ValueError()
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idc = np.unique(mask_idc[mask_idc < sz])
if len(mask_idc) >= sz:
raise ValueError(
(
f"the entire sequence is masked. "
f"sz={sz}; mask_idc[mask_idc]; "
f"index={indices[i] if indices is not None else None}"
)
)
mask_idcs.append(mask_idc)
target_len = None
if require_same_masks:
if add_masks:
target_len = max([len(m) for m in mask_idcs])
else:
target_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if target_len is not None and len(mask_idc) > target_len:
mask_idc = rng.choice(mask_idc, target_len, replace=False)
mask[i, mask_idc] = True
if target_len is not None and len(mask_idc) < target_len:
unmasked = np.flatnonzero(~mask[i])
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
mask[i, to_mask] = True
if mask_dropout > 0:
masked = np.flatnonzero(mask[i])
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
to_drop = rng.choice(masked, num_holes, replace=False)
mask[i, to_drop] = False
return mask
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
class HubertModel(nn.Module):
def __init__(
self,
cfg,
) -> None:
super().__init__()
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
self.post_extract_proj = (
nn.Linear(self.embed, encoder_input_dim)
if self.embed != encoder_input_dim
else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())
self.encoder = Zipformer2(
output_downsampling_factor=1,
downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
encoder_dim=_to_int_tuple(cfg.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(cfg.query_head_dim),
pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
value_head_dim=_to_int_tuple(cfg.value_head_dim),
pos_dim=cfg.pos_dim,
num_heads=_to_int_tuple(cfg.num_heads),
feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
)
self.layer_norm = LayerNorm(self.embed)
self.untie_final_proj = cfg.untie_final_proj
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
# modules below are not needed during fine-tuning
self.num_classes = cfg.num_classes
self.pred_masked_weight = cfg.pred_masked_weight
self.pred_nomask_weight = cfg.pred_nomask_weight
self.loss_weights = cfg.loss_weights
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
super().upgrade_state_dict_named(state_dict, name)
return state_dict
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb.to(x.dtype)
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_features(self, source: torch.Tensor) -> torch.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
targ_tsz = min([t.size(1) for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[..., :feat_tsz]
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
target_list = [t[:, target_inds.long()] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
def forward(
self,
source: torch.Tensor,
target_list: Optional[List[torch.Tensor]] = None,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = True,
features_only: bool = False,
output_layer: Optional[int] = None,
):
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.float().pow(2).mean()
features = features.transpose(1, 2)
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float -> (T, B, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x = x.transpose(0, 1)
x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
x = x.transpose(0, 1)
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
if not self.skip_masked:
masked_indices = torch.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
proj_x_m /= self.logit_temp
logit_m_list = [proj_x_m for _ in range(len(target_list))]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
proj_x_u /= self.logit_temp
logit_u_list = [proj_x_u for _ in range(len(target_list))]
else:
logit_u_list = [None for _ in target_list]
# result = {
# "logit_m_list": logit_m_list,
# "logit_u_list": logit_u_list,
# "padding_mask": padding_mask,
# "features_pen": features_pen,
# }
targ_m_list = target_list[0][masked_indices]
targ_m_list = targ_m_list.long()
targ_m_list = [targ_m_list for _ in range(len(target_list))]
targ_u_list = target_list[0][nomask_indices]
targ_u_list = targ_u_list.long()
targ_u_list = [targ_u_list for _ in range(len(target_list))]
return self.compute_loss(
logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
)
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer,
)
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [x.float() for x in logits_list if x is not None]
return logits_list
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self):
self.final_proj = None
def compute_loss(
self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
):
loss = 0.0
sample_size = 0
logging_output = {}
reduce = True
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = [x.float() for x in logit_m_list if x is not None]
logp_m_list = torch.cat(logp_m_list)
targ_m_list = torch.cat(targ_m_list)
loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_0"] = loss_m.detach().item()
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += len(targ_m_list)
loss_u_list = []
logp_u_list = [x.float() for x in logit_u_list if x is not None]
logp_u_list = torch.cat(logp_u_list)
targ_u_list = torch.cat(targ_u_list)
loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_0"] = loss_u.detach().item()
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += len(targ_u_list)
if self.loss_weights is not None:
extra_losses = []
names = []
extra_losses.append(features_pen)
names.append("features_pen")
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
**logging_output,
}
# for lk in self.log_keys:
# if lk in net_output:
# logging_output[lk] = float((net_output[lk]))
def compute_correct(logits, target):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == target
min = logits.argmin(-1) == target
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
logging_output[f"correct_m_0"] = corr_m
logging_output[f"count_m_0"] = count_m
corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
logging_output[f"correct_u_0"] = corr_u
logging_output[f"count_u_0"] = count_u
return loss, sample_size, logging_output

File diff suppressed because it is too large Load Diff