# Copyright (c) Facebook, Inc. and its affiliates. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
import argparse
import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import ScheduledFloat
from utils import GradMultiply, LayerNorm
from wav2vec2_module import ConvFeatureExtractionModel
from zipformer import Zipformer2


def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """Compute random mask spans for a given shape.

    Args:
        shape: ``(batch_size, timesteps)`` of the mask to compute.
        padding_mask: optional tensor of the same size as ``shape`` where True
            marks padded positions. The number of padded positions in a row is
            subtracted from that row's length, so spans are only placed in the
            first ``timesteps - n_padded`` positions (this assumes padding sits
            at the end of each row).
        mask_prob: probability for each token to be chosen as the start of a
            span to be masked. It is multiplied by the number of timesteps
            divided by the span length to mask approximately this fraction of
            all elements; because spans can overlap, the actual fraction is
            usually smaller (unless ``no_overlap`` is True).
        mask_length: span length, or the distribution parameter for
            ``mask_type`` (see below).
        mask_type: how to compute span lengths:
            static  = fixed size ``mask_length``
            uniform = sample uniformly from [mask_other, mask_length * 2]
            normal  = sample from a normal distribution with mean
                      ``mask_length`` and stdev ``mask_other``; each span is
                      at least 1 element
            poisson = sample from a Poisson distribution with
                      lambda = ``mask_length``
        mask_other: secondary distribution parameter (see ``mask_type``).
        min_masks: minimum number of masked spans per row.
        no_overlap: if True, switch to an alternative algorithm that places
            spans one at a time inside the remaining free intervals, which
            prevents spans from overlapping.
        min_space: only used if ``no_overlap`` is True; how many elements to
            keep unmasked between spans.
        require_same_masks: if True, randomly drop masked positions (or, with
            ``add_masks``, randomly add them) until every row has the same
            number of masked positions.
        mask_dropout: fraction of masked positions in each row to unmask again.
        add_masks: with ``require_same_masks``, equalize to the largest row
            count instead of the smallest.
        seed, epoch, indices: when all three are given, row ``i`` gets its own
            RNG seeded from ``hash((seed, epoch, indices[i]))``.
        idc_select_ver: 1 = span starts are drawn from ``[0, sz - min_len)``;
            2 = span starts are drawn from ``[0, sz)`` and spans are clipped.
        num_mask_ver: 1 = probabilistic rounding using the global NumPy RNG;
            2 = probabilistic rounding using the per-row RNG.

    Returns:
        Boolean array of shape ``(batch_size, timesteps)``; True = masked.

    Raises:
        ValueError: on an unknown version flag, on zero static span length,
            or if a row would end up entirely masked.
        Exception: on an unknown ``mask_type``.
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        # NOTE: this is a np.random.Generator, whose API is `integers`,
        # `random`, `choice`, ... (it has no legacy `randint`).
        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # Generator.integers: high bound is exclusive, like legacy randint.
            lengths = rng.integers(
                int(mask_other), mask_length * 2 + 1, size=num_mask
            )
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span inside [s, e) and return the free
                # sub-intervals (respecting min_space) still big enough to use.
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + offset for offset in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            # Longest spans first, each into an interval chosen with
            # probability proportional to its usable size.
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / l_sum
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            # int64 so an empty result is still usable as an index array.
            mask_idc = np.asarray(mask_idc, dtype=np.int64)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # Expand every span start into its covered positions.
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )

        # Clip spans running past the (unpadded) end and drop duplicates.
        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    # NOTE: `rng` below is the generator of the last row from the loop above.
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask


def _to_int_tuple(s: str):
    """Parse a comma-separated string such as ``"2,4,8"`` into a tuple of ints."""
    return tuple(map(int, s.split(",")))


class HubertModel(nn.Module):
    """HuBERT-style masked-prediction model using a Zipformer2 encoder.

    ``cfg`` is a config object; every attribute read in ``__init__`` must be
    present on it (e.g. an argparse namespace or dataclass — confirm with the
    caller).
    """

    def __init__(
        self,
        cfg,
    ) -> None:
        super().__init__()
        # NOTE(review): eval() on a config string. The value is expected to be
        # an expression such as "[(512,10,5)] + [(512,3,2)] * 4", which
        # ast.literal_eval cannot parse; only use trusted configs.
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        # Total stride of the conv stack, i.e. input samples per feature frame.
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # Label frames per feature frame.
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
        encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
        encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
        self.post_extract_proj = (
            nn.Linear(self.embed, encoder_input_dim)
            if self.embed != encoder_input_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        # Learned vector written into masked time steps by apply_mask().
        self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())

        self.encoder = Zipformer2(
            output_downsampling_factor=1,
            downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
            num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
            encoder_dim=_to_int_tuple(cfg.encoder_dim),
            encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
            query_head_dim=_to_int_tuple(cfg.query_head_dim),
            pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
            value_head_dim=_to_int_tuple(cfg.value_head_dim),
            pos_dim=cfg.pos_dim,
            num_heads=_to_int_tuple(cfg.num_heads),
            feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
            cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
            warmup_batches=4000.0,
        )

        self.layer_norm = LayerNorm(self.embed)

        self.untie_final_proj = cfg.untie_final_proj
        self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))

        # modules below are not needed during fine-tuning
        self.num_classes = cfg.num_classes
        self.pred_masked_weight = cfg.pred_masked_weight
        self.pred_nomask_weight = cfg.pred_nomask_weight
        self.loss_weights = cfg.loss_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        ``nn.Module`` itself defines no ``upgrade_state_dict_named``, so the
        parent hook is only invoked when some base class actually provides it.
        The state dict is returned (possibly modified in place by the parent).
        """
        parent_hook = getattr(super(), "upgrade_state_dict_named", None)
        if parent_hook is not None:
            parent_hook(state_dict, name)
        return state_dict

    def apply_mask(self, x, padding_mask, target_list):
        """Apply time masking and channel masking to ``x`` in place.

        Args:
            x: (B, T, C) features; modified in place.
            padding_mask: optional (B, T) bool mask, True = padded.
            target_list: unused; kept for interface compatibility.

        Returns:
            ``(x, mask_indices)`` where ``mask_indices`` is a (B, T) bool tensor
            of time-masked frames, or None when ``mask_prob`` is 0.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb.to(x.dtype)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-channel mask over all time steps.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        """Run the conv feature extractor.

        With ``feature_grad_mult > 0`` gradients are scaled by that factor
        (via ``GradMultiply`` when it is not 1.0); with ``<= 0`` the extractor
        runs under ``torch.no_grad()``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Trim features so every frame has a label, then align the labels.

        ``features`` has time on dim 2; each target has time on dim 1. Labels
        are subsampled at ``frame_index * feat2tar_ratio``.
        """
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample an input-rate padding mask to the feature frame rate.

        Trailing samples that do not fill a whole frame are dropped; a frame
        is marked padded only if all of its samples are padded.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ):
        """Encode ``source`` and, unless ``features_only``, compute the loss.

        Args:
            source: raw input for the conv feature extractor.
            target_list: label tensors; targets are taken from
                ``target_list[0]``. Required unless ``features_only``.
            padding_mask: optional input-rate padding mask (True = padded).
                None means nothing is padded.
            mask: whether to apply time/channel masking.
            features_only: return encoder output instead of the loss.
            output_layer: 1-based layer index; currently ignored.

        Returns:
            If ``features_only``: dict with ``"x"`` (B, T, D encoder output),
            ``"padding_mask"`` (frame-rate mask or None) and ``"features"``.
            Otherwise the ``(loss, sample_size, logging_output)`` tuple from
            ``compute_loss``.
        """
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        # NOTE(review): unmasked_features is not used below; it is kept because
        # dropout_features draws from the RNG and removing it would change
        # training randomness.
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # x: (B, T, D) float; padding_mask / mask_indices: (B, T) bool or None.
        # Build concrete masks so the encoder and index selection below also
        # work without a padding mask or without masking.
        if padding_mask is None:
            valid_frames = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
        else:
            valid_frames = ~padding_mask
        if mask_indices is None:
            mask_indices = torch.zeros_like(valid_frames)

        # The encoder consumes (T, B, D) plus per-utterance lengths.
        x = x.transpose(0, 1)
        x, _ = self.encoder(x, valid_frames.sum(dim=-1))
        x = x.transpose(0, 1)

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        masked_indices = torch.logical_and(valid_frames, mask_indices)
        nomask_indices = torch.logical_and(valid_frames, ~mask_indices)

        if not self.skip_masked:
            proj_x_m = self.final_proj(x[masked_indices])
            proj_x_m /= self.logit_temp
            logit_m_list = [proj_x_m for _ in range(len(target_list))]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            proj_x_u = self.final_proj(x[nomask_indices])
            proj_x_u /= self.logit_temp
            logit_u_list = [proj_x_u for _ in range(len(target_list))]
        else:
            logit_u_list = [None for _ in target_list]

        targ_m = target_list[0][masked_indices].long()
        targ_m_list = [targ_m for _ in range(len(target_list))]

        targ_u = target_list[0][nomask_indices].long()
        targ_u_list = [targ_u for _ in range(len(target_list))]

        return self.compute_loss(
            logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
        )

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, padding_mask)``: conv features when
        ``ret_conv`` else encoder output."""
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        """Return float copies of the non-None masked (or unmasked) logits.

        Expects ``net_output`` to hold ``"logit_m_list"`` / ``"logit_u_list"``
        (note: ``forward`` in this class returns the loss tuple instead).
        """
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        """Return an all-zero long tensor per row for each logits tensor."""
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        """Collect ``features_pen`` from ``net_output`` if present."""
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        """Drop the prediction head that is only needed for pretraining."""
        self.final_proj = None

    def compute_loss(
        self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
    ):
        """Compute the summed cross-entropy loss on masked/unmasked frames.

        Returns:
            ``(loss, sample_size, logging_output)``. ``loss`` is the weighted
            sum of the masked loss, the unmasked loss and, when
            ``loss_weights`` is set, the feature penalty scaled by
            ``sample_size``. ``logging_output`` holds scalar losses and
            correct/count statistics.
        """
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduce = True
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = [x.float() for x in logit_m_list if x is not None]
        logp_m_list = torch.cat(logp_m_list)
        targ_m_list = torch.cat(targ_m_list)

        loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
        loss_m_list.append(loss_m)
        logging_output["loss_m_0"] = loss_m.detach().item()

        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += len(targ_m_list)

        loss_u_list = []
        logp_u_list = [x.float() for x in logit_u_list if x is not None]
        logp_u_list = torch.cat(logp_u_list)
        targ_u_list = torch.cat(targ_u_list)

        loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
        loss_u_list.append(loss_u)
        logging_output["loss_u_0"] = loss_u.detach().item()

        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += len(targ_u_list)

        if self.loss_weights is not None:
            extra_losses = [features_pen]
            names = ["features_pen"]
            # A single weight is broadcast to all extra losses.
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            # float() also covers the case where no term was added and loss is
            # still the Python float 0.0 (which has no .item()).
            "loss": float(loss) if reduce else loss,
            **logging_output,
        }

        def compute_correct(logits, target):
            # A row whose argmax equals its argmin (e.g. all-equal logits) is
            # not counted as correct.
            if logits.numel() == 0:
                return 0, 0
            assert logits.dim() > 1, logits.shape
            is_argmax = logits.argmax(-1) == target
            is_argmin = logits.argmin(-1) == target
            both = is_argmax & is_argmin
            corr = is_argmax.long().sum().item() - both.long().sum().item()
            count = is_argmax.numel()
            return corr, count

        with torch.no_grad():
            corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
            logging_output["correct_m_0"] = corr_m
            logging_output["count_m_0"] = count_m

            corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
            logging_output["correct_u_0"] = corr_u
            logging_output["count_u_0"] = count_u

        return loss, sample_size, logging_output