diff --git a/egs/librispeech/ASR/local/data2vec_audio.py b/egs/librispeech/ASR/local/data2vec_audio.py
deleted file mode 100644
index aaf23e6d3..000000000
--- a/egs/librispeech/ASR/local/data2vec_audio.py
+++ /dev/null
@@ -1,686 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import sys
-sys.path.append('../pruned_transducer_stateless_d2v_v2')
-
-import argparse
-import logging
-import math
-from dataclasses import dataclass, field
-from typing import Optional
-
-from omegaconf import II
-
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.distributed as dist
-
-from fairseq.data.data_utils import compute_mask_indices
-from fairseq.models import BaseFairseqModel, register_model
-from fairseq.models.wav2vec import (
-    ConvFeatureExtractionModel,
-    Wav2Vec2Config,
-    TransformerEncoder,
-)
-from fairseq.modules import (
-    GradMultiply,
-    LayerNorm,
-)
-from fairseq.utils import index_put
-from utils import pad_to_multiple
-
-from convolution import ConvolutionModule
-
-logging.getLogger().setLevel(logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-class TransformerEncoderAdapter(TransformerEncoder):
-    def __init__(self, args: Wav2Vec2Config):
-        super().__init__(args)
-        self.adapters = ResidualAdapterModule()
-
-        for p in self.adapters.parameters():
-            p.data /= 10.
-            #p.data = nn.Parameter(torch.zeros(p.size()).to('cuda'))
-            #p.data = nn.Parameter(torch.randn(p.size()).to('cuda')/20.)
-
-    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
-        x, layer_results = self.extract_features_with_adapter(
-            x,
-            padding_mask=padding_mask,
-            tgt_layer=tgt_layer,
-        )
-
-        if self.layer_norm_first and layer is None:
-            x = self.layer_norm(x)
-
-        return x, layer_results
-
-    def extract_features_with_adapter(
-        self,
-        x,
-        padding_mask=None,
-        tgt_layer=None,
-        min_layer=0,
-    ):
-        if padding_mask is not None:
-            x = index_put(x, padding_mask, 0)
-
-        x_conv = self.pos_conv(x.transpose(1, 2))
-        x_conv = x_conv.transpose(1, 2)
-        x = x + x_conv
-
-        if not self.layer_norm_first:
-            x = self.layer_norm(x)
-
-        # pad to the sequence length dimension
-        x, pad_length = pad_to_multiple(
-            x, self.required_seq_len_multiple, dim=-2, value=0
-        )
-        if pad_length > 0 and padding_mask is None:
-            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
-            padding_mask[:, -pad_length:] = True
-        else:
-            padding_mask, _ = pad_to_multiple(
-                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
-            )
-        x = F.dropout(x, p=self.dropout, training=self.training)
-
-        # B x T x C -> T x B x C
-        x = x.transpose(0, 1)
-
-        layer_results = []
-        r = None
-
-        for i, layer in enumerate(self.layers):
-            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
-            if not self.training or (dropout_probability > self.layerdrop):
-                x, (z, lr) = layer(
-                    x, self_attn_padding_mask=padding_mask, need_weights=False,
-                )
-                x = self.adapters(x, layer_id=i)
-
-                if i >= min_layer:
-                    layer_results.append((x, z, lr))
-
-            if i == tgt_layer:
-                r = x
-                break
-
-        if r is not None:
-            x = r
-
-        # T x B x C -> B x T x C
-        x = x.transpose(0, 1)
-
-        # undo padding
-        if pad_length > 0:
-            x = x[:, :-pad_length]
-
-            def undo_pad(a, b, c):
-                return (
-                    a[:-pad_length],
-                    b[:-pad_length] if b is not None else b,
-                    c[:-pad_length],
-                )
-
-            layer_results = [undo_pad(*u) for u in layer_results]
-
-        return x, layer_results
-
-
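-# Shape sketch for the adapter hook above (illustrative; assumes the default
-# base config with embedding_dim=768 and 12 layers). Inside the layer loop,
-# `x` is laid out (T, B, C); ResidualAdapterModule transposes to (B, T, C),
-# applies a per-layer bottleneck with a residual connection, and transposes
-# back, so the encoder loop sees an unchanged (T, B, C) layout:
-#
-#     adapters = ResidualAdapterModule(embedding_dim=768, layer_num=12)
-#     x = torch.randn(100, 4, 768)      # (T, B, C), as in the encoder loop
-#     assert adapters(x, layer_id=0).shape == x.shape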
-class ResidualAdapterModule(nn.Module):
-    """
-    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf.
-    The modules match the original residual adapter except for the LayerNorm
-    location (moved from first to last).
-    """
-    def __init__(
-        self,
-        embedding_dim: int = 768,
-        layer_num: int = 12,
-        proj_dim: int = 512,
-    ) -> None:
-        super().__init__()
-
-        self.type = 'linear'
-
-        def build_adapter(embedding_dim, proj_dim, type_=self.type):
-            if type_ == 'conv':
-                return ConvolutionModule(768, 31)
-            else:
-                return nn.Sequential(
-                    #nn.LayerNorm(embedding_dim),
-                    nn.Linear(embedding_dim, proj_dim),
-                    nn.ReLU(),
-                    nn.Linear(proj_dim, embedding_dim),
-                    nn.LayerNorm(embedding_dim),
-                )
-
-        self.adapter_layers = nn.ModuleList(
-            [build_adapter(embedding_dim, proj_dim, type_=self.type) for _ in range(layer_num)]
-        )
-
-    def forward(self, x, layer_id=-1):
-        x = x.transpose(0, 1)
-        residual = x
-        x = self.adapter_layers[layer_id](x)
-        x = residual + x
-        x = x.transpose(0, 1)
-
-        return x
-
-
-@dataclass
-class Data2VecAudioConfig(Wav2Vec2Config):
-
-    loss_beta: float = field(
-        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
-    )
-    loss_scale: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
-        },
-    )
-    average_top_k_layers: int = field(
-        default=8, metadata={"help": "how many layers to average"}
-    )
-
-    layer_norm_target_layer: bool = False
-    instance_norm_target_layer: bool = False
-    instance_norm_targets: bool = False
-    layer_norm_targets: bool = False
-    batch_norm_target_layer: bool = False
-    group_norm_target_layer: bool = False
-
-    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
-    ema_end_decay: float = field(
-        default=0.9999, metadata={"help": "final ema decay rate"}
-    )
-
-    # when to finish annealing ema decay rate
-    ema_anneal_end_step: int = II("optimization.max_update")
-
-    ema_transformer_only: bool = field(
-        default=True,
-        metadata={"help": "whether to momentum update only the transformer"},
-    )
-    ema_layers_only: bool = field(
-        default=True,
-        metadata={"help": "whether to momentum update only the transformer layers"},
-    )
-
-    max_update: int = II("optimization.max_update")
-
-    min_target_var: float = field(
-        default=0.1, metadata={"help": "stop training if target var falls below this"}
-    )
-    min_pred_var: float = field(
-        default=0.01,
-        metadata={"help": "stop training if prediction var falls below this"},
-    )
-
-
-def get_annealed_rate(start, end, curr_step, total_steps):
-    r = end - start
-    pct_remaining = 1 - curr_step / total_steps
-    return end - r * pct_remaining
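-
-
-# Worked example for the annealing above (illustrative; uses the config
-# defaults ema_decay=0.999 and ema_end_decay=0.9999). Halfway through
-# annealing, with curr_step=50_000 of total_steps=100_000:
-#     r = 0.9999 - 0.999 = 0.0009
-#     pct_remaining = 1 - 0.5 = 0.5
-#     rate = 0.9999 - 0.0009 * 0.5 = 0.99945
-# i.e. the decay interpolates linearly from `start` at step 0 to `end` at
-# `total_steps`.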
-
-
-@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
-class Data2VecAudioModel(BaseFairseqModel):
-    def __init__(self, cfg: Data2VecAudioConfig):
-        super().__init__()
-        self.cfg = cfg
-
-        feature_enc_layers = eval(cfg.conv_feature_layers)
-        self.extractor_embed = feature_enc_layers[-1][0]
-
-        self.ema = None
-        self.embed = cfg.encoder_embed_dim
-
-        self.average_top_k_layers = cfg.average_top_k_layers
-        self.loss_beta = cfg.loss_beta
-        self.loss_scale = cfg.loss_scale
-
-        self.feature_extractor = ConvFeatureExtractionModel(
-            conv_layers=feature_enc_layers,
-            dropout=0.0,
-            mode=cfg.extractor_mode,
-            conv_bias=cfg.conv_bias,
-        )
-
-        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
-
-        self.mask_prob = cfg.mask_prob
-        self.mask_selection = cfg.mask_selection
-        self.mask_other = cfg.mask_other
-        self.mask_length = cfg.mask_length
-        self.no_mask_overlap = cfg.no_mask_overlap
-        self.mask_min_space = cfg.mask_min_space
-
-        self.mask_channel_prob = cfg.mask_channel_prob
-        self.mask_channel_before = cfg.mask_channel_before
-        self.mask_channel_selection = cfg.mask_channel_selection
-        self.mask_channel_other = cfg.mask_channel_other
-        self.mask_channel_length = cfg.mask_channel_length
-        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
-        self.mask_channel_min_space = cfg.mask_channel_min_space
-
-        self.dropout_input = nn.Dropout(cfg.dropout_input)
-        self.dropout_features = nn.Dropout(cfg.dropout_features)
-
-        self.feature_grad_mult = cfg.feature_grad_mult
-
-        self.mask_emb = nn.Parameter(
-            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
-        )
-
-        #self.encoder = TransformerEncoder(cfg)
-        self.encoder = TransformerEncoderAdapter(cfg)
-        self.layer_norm = LayerNorm(self.extractor_embed)
-
-        self.final_proj = nn.Linear(self.embed, self.embed)
-
-        self.num_updates = 0
-
-    '''
-    def make_ema_teacher(self):
-        ema_config = EMAModuleConfig(
-            ema_decay=self.cfg.ema_decay,
-            ema_fp32=True,
-        )
-        skip_keys = set()
-        if self.cfg.ema_layers_only:
-            self.cfg.ema_transformer_only = True
-            for k, _ in self.encoder.pos_conv.named_parameters():
-                skip_keys.add(f"pos_conv.{k}")
-
-        self.ema = EMAModule(
-            self.encoder if self.cfg.ema_transformer_only else self,
-            ema_config,
-            skip_keys=skip_keys,
-        )
-    '''
-
-    def set_num_updates(self, num_updates):
-        super().set_num_updates(num_updates)
-
-        '''
-        if self.ema is None and self.final_proj is not None:
-            logger.info(f"making ema teacher")
-            self.make_ema_teacher()
-        elif self.training and self.ema is not None:
-            if self.cfg.ema_decay != self.cfg.ema_end_decay:
-                if num_updates >= self.cfg.ema_anneal_end_step:
-                    decay = self.cfg.ema_end_decay
-                else:
-                    decay = get_annealed_rate(
-                        self.cfg.ema_decay,
-                        self.cfg.ema_end_decay,
-                        num_updates,
-                        self.cfg.ema_anneal_end_step,
-                    )
-                self.ema.set_decay(decay)
-            if self.ema.get_decay() < 1:
-                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
-        '''
-        self.num_updates = num_updates
-
-    def state_dict(self, destination=None, prefix="", keep_vars=False):
-        state = super().state_dict(destination, prefix, keep_vars)
-
-        if self.ema is not None:
-            state[prefix + "_ema"] = self.ema.fp32_params
-
-        return state
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        if self.ema is not None:
-            k = prefix + "_ema"
-            assert k in state_dict
-            self.ema.restore(state_dict[k], True)
-            del state_dict[k]
-        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-    @classmethod
-    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
-        """Build a new model instance."""
-
-        return cls(cfg)
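-
-    # Rough masking math for apply_mask below (illustrative; exact counts
-    # depend on fairseq's compute_mask_indices sampling): with wav2vec2-style
-    # settings mask_prob=0.65 and mask_length=10 on T=400 frames, about
-    # 0.65 * 400 / 10 = 26 span starts are sampled per utterance, and since
-    # spans may overlap, somewhat less than 65% of timesteps end up masked.
-    # Note that this port writes 0 into masked positions
-    # (index_put(x, mask_indices, 0)) instead of the learned self.mask_emb
-    # used upstream.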
-    def apply_mask(
-        self,
-        x,
-        padding_mask,
-        mask_indices=None,
-        mask_channel_indices=None,
-    ):
-        B, T, C = x.shape
-
-        if self.mask_channel_prob > 0 and self.mask_channel_before:
-            mask_channel_indices = compute_mask_indices(
-                (B, C),
-                None,
-                self.mask_channel_prob,
-                self.mask_channel_length,
-                self.mask_channel_selection,
-                self.mask_channel_other,
-                no_overlap=self.no_mask_channel_overlap,
-                min_space=self.mask_channel_min_space,
-            )
-            mask_channel_indices = (
-                torch.from_numpy(mask_channel_indices)
-                .to(x.device)
-                .unsqueeze(1)
-                .expand(-1, T, -1)
-            )
-            x[mask_channel_indices] = 0
-
-        if self.mask_prob > 0:
-            if mask_indices is None:
-                mask_indices = compute_mask_indices(
-                    (B, T),
-                    padding_mask,
-                    self.mask_prob,
-                    self.mask_length,
-                    self.mask_selection,
-                    self.mask_other,
-                    min_masks=1,
-                    no_overlap=self.no_mask_overlap,
-                    min_space=self.mask_min_space,
-                    require_same_masks=self.cfg.require_same_masks,
-                    mask_dropout=self.cfg.mask_dropout,
-                )
-                mask_indices = torch.from_numpy(mask_indices).to(x.device)
-            #x = index_put(x, mask_indices, self.mask_emb)
-            x = index_put(x, mask_indices, 0)
-        else:
-            mask_indices = None
-
-        if self.mask_channel_prob > 0 and not self.mask_channel_before:
-            if mask_channel_indices is None:
-                mask_channel_indices = compute_mask_indices(
-                    (B, C),
-                    None,
-                    self.mask_channel_prob,
-                    self.mask_channel_length,
-                    self.mask_channel_selection,
-                    self.mask_channel_other,
-                    no_overlap=self.no_mask_channel_overlap,
-                    min_space=self.mask_channel_min_space,
-                )
-                mask_channel_indices = (
-                    torch.from_numpy(mask_channel_indices)
-                    .to(x.device)
-                    .unsqueeze(1)
-                    .expand(-1, T, -1)
-                )
-            x = index_put(x, mask_channel_indices, 0)
-
-        return x, mask_indices
-
-    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
-        """
-        Computes the output length of the convolutional layers
-        """
-
-        def _conv_out_length(input_length, kernel_size, stride):
-            return torch.floor((input_length - kernel_size) / stride + 1)
-
-        conv_cfg_list = eval(self.cfg.conv_feature_layers)
-
-        for i in range(len(conv_cfg_list)):
-            input_lengths = _conv_out_length(
-                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
-            )
-
-        return input_lengths.to(torch.long)
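-
-    # Worked example for the length formula above (illustrative; assumes the
-    # default conv_feature_layers "[(512,10,5)] + [(512,3,2)]*4 + [(512,2,2)]*2"):
-    # a 16000-sample (1 s at 16 kHz) input shrinks layer by layer as
-    # floor((L - kernel) / stride + 1):
-    #     16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49
-    # i.e. one encoder frame per ~20 ms of audio (~50 Hz frame rate).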
-    def forward(
-        self,
-        source,
-        padding_mask=None,
-        mask=True,
-        features_only=False,
-        layer=None,
-        mask_indices=None,
-        mask_channel_indices=None,
-        padding_count=None,
-    ):
-        features = source
-
-        if self.feature_grad_mult > 0:
-            features = self.feature_extractor(features)
-            if self.feature_grad_mult != 1.0:
-                features = GradMultiply.apply(features, self.feature_grad_mult)
-        else:
-            with torch.no_grad():
-                features = self.feature_extractor(features)
-
-        features = features.transpose(1, 2)
-
-        features = self.layer_norm(features)
-
-        orig_padding_mask = padding_mask
-
-        if padding_mask is not None and padding_mask.any():
-            input_lengths = (1 - padding_mask.long()).sum(-1)
-            # apply conv formula to get real output_lengths
-            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
-
-            padding_mask = torch.zeros(
-                features.shape[:2], dtype=features.dtype, device=features.device
-            )
-
-            # these two operations make sure that all values
-            # before the output length indices are attended to
-            padding_mask[
-                (
-                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
-                    output_lengths - 1,
-                )
-            ] = 1
-            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
-        else:
-            padding_mask = None
-
-        if self.post_extract_proj is not None:
-            features = self.post_extract_proj(features)
-
-        pre_encoder_features = None
-        if self.cfg.ema_transformer_only:
-            pre_encoder_features = features.clone()
-
-        features = self.dropout_input(features)
-
-        if mask:
-            x, mask_indices = self.apply_mask(
-                features,
-                padding_mask,
-                mask_indices=mask_indices,
-                mask_channel_indices=mask_channel_indices,
-            )
-        else:
-            x = features
-            mask_indices = None
-
-        x, layer_results = self.encoder(
-            x,
-            padding_mask=padding_mask,
-            layer=layer,
-        )
-
-        if features_only:
-            return {
-                "x": x,
-                "padding_mask": padding_mask,
-                "layer_results": layer_results,
-            }
-
-        result = {
-            "losses": {},
-        }
-
-        with torch.no_grad():
-            self.ema.model.eval()
-
-            if self.cfg.ema_transformer_only:
-                y, layer_results = self.ema.model.extract_features(
-                    pre_encoder_features,
-                    padding_mask=padding_mask,
-                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
-                )
-                y = {
-                    "x": y,
-                    "padding_mask": padding_mask,
-                    "layer_results": layer_results,
-                }
-            else:
-                y = self.ema.model.extract_features(
-                    source=source,
-                    padding_mask=orig_padding_mask,
-                    mask=False,
-                )
-
-            target_layer_results = [l[2] for l in y["layer_results"]]
-
-            permuted = False
-            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
-                target_layer_results = [
-                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
-                ]
-                permuted = True
-
-            if self.cfg.batch_norm_target_layer:
-                target_layer_results = [
-                    F.batch_norm(
-                        tl.float(), running_mean=None, running_var=None, training=True
-                    )
-                    for tl in target_layer_results
-                ]
-
-            if self.cfg.instance_norm_target_layer:
-                target_layer_results = [
-                    F.instance_norm(tl.float()) for tl in target_layer_results
-                ]
-
-            if permuted:
-                target_layer_results = [
-                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
-                ]
-
-            if self.cfg.group_norm_target_layer:
-                target_layer_results = [
-                    F.layer_norm(tl.float(), tl.shape[-2:])
-                    for tl in target_layer_results
-                ]
-
-            if self.cfg.layer_norm_target_layer:
-                target_layer_results = [
-                    F.layer_norm(tl.float(), tl.shape[-1:])
-                    for tl in target_layer_results
-                ]
-
-            y = sum(target_layer_results) / len(target_layer_results)
-
-            if self.cfg.layer_norm_targets:
-                y = F.layer_norm(y.float(), y.shape[-1:])
-
-            if self.cfg.instance_norm_targets:
-                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
-
-            if not permuted:
-                y = y.transpose(0, 1)
-
-            y = y[mask_indices]
-
-        x = x[mask_indices]
-        x = self.final_proj(x)
-
-        sz = x.size(-1)
-
-        if self.loss_beta == 0:
-            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
-        else:
-            loss = F.smooth_l1_loss(
-                x.float(), y.float(), reduction="none", beta=self.loss_beta
-            ).sum(dim=-1)
-
-        if self.loss_scale is not None:
-            scale = self.loss_scale
-        else:
-            scale = 1 / math.sqrt(sz)
-
-        result["losses"]["regression"] = loss.sum() * scale
-
-        if "sample_size" not in result:
-            result["sample_size"] = loss.numel()
-
-        with torch.no_grad():
-            result["target_var"] = self.compute_var(y)
-            result["pred_var"] = self.compute_var(x.float())
-
-        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
-            logger.error(
-                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
-            )
-            raise Exception(
-                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
-            )
-        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
-            logger.error(
-                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
-            )
-            raise Exception(
-                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
-            )
-
-        if self.ema is not None:
-            result["ema_decay"] = self.ema.get_decay() * 1000
-
-        return result
-
-    @staticmethod
-    def compute_var(y):
-        y = y.view(-1, y.size(-1))
-        if dist.is_initialized():
-            zc = torch.tensor(y.size(0)).cuda()
-            zs = y.sum(dim=0)
-            zss = (y ** 2).sum(dim=0)
-
-            dist.all_reduce(zc)
-            dist.all_reduce(zs)
-            dist.all_reduce(zss)
-
-            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
-            return torch.sqrt(var + 1e-6).mean()
-        else:
-            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
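-
-    # The distributed branch above computes the usual unbiased sample variance
-    # from sufficient statistics all-reduced across workers (illustrative
-    # check of the algebra):
-    #     var = (sum(y^2) - (sum(y))^2 / n) / (n - 1)
-    #         = zss / (zc - 1) - zs^2 / (zc * (zc - 1))
-    # so both branches agree with y.var(dim=0) up to the 1e-6 epsilon added
-    # before the square root.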
-    def extract_features(
-        self, source, padding_mask, mask=False, layer=None
-    ):
-        res = self.forward(
-            source,
-            padding_mask,
-            mask=mask,
-            features_only=True,
-            layer=layer,
-        )
-        return res
-
-    def remove_pretraining_modules(self, last_layer=None):
-        self.final_proj = None
-        self.ema = None
-        if last_layer is not None:
-            self.encoder.layers = nn.ModuleList(
-                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
-            )
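-
-
-# Typical downstream usage (illustrative sketch; `cfg` is a populated
-# Data2VecAudioConfig and `wav` a (B, num_samples) float waveform batch).
-# Since make_ema_teacher is disabled in this port and self.ema stays None,
-# only the features_only path exercised by extract_features is usable:
-#
-#     model = Data2VecAudioModel.build_model(cfg)
-#     model.remove_pretraining_modules()    # drop final_proj / EMA state
-#     out = model.extract_features(wav, padding_mask=None, mask=False)
-#     feats = out["x"]                      # (B, T', encoder_embed_dim)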