diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/wav2vec.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/wav2vec.py deleted file mode 100644 index af6604da1..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/wav2vec.py +++ /dev/null @@ -1,630 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field -import logging -import math -from typing import Optional, Tuple -from omegaconf import II -import sys - -import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.models import BaseFairseqModel, register_model -from fairseq.modules import ( - Fp32GroupNorm, - Fp32LayerNorm, - GumbelVectorQuantizer, - KmeansVectorQuantizer, - TransposeLast, -) -from fairseq.tasks import FairseqTask -from fairseq.utils import buffered_arange - - -logger = logging.getLogger(__name__) - - -AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) -PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) -ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) -VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"]) - - -@dataclass -class Wav2VecConfig(FairseqDataclass): - prediction_steps: int = field( - default=12, metadata={"help": "number of steps ahead to predict"} - ) - sample_distance: Optional[int] = field( - default=None, - metadata={ - "help": "sample distance from target. does not work properly with cross-sampling" - }, - ) - cross_sample_negatives: int = field( - default=0, metadata={"help": "num of cross sampled negatives"} - ) - num_negatives: int = field( - default=10, metadata={"help": "num of sampled negatives"} - ) - conv_feature_layers: str = field( - default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", - metadata={ - "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]" - }, - ) - conv_aggregator_layers: str = field( - default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]", - metadata={ - "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]" - }, - ) - dropout: float = field( - default=0.0, metadata={"help": "dropout to apply within the model"} - ) - dropout_features: float = field( - default=0.0, metadata={"help": "dropout to apply to the features"} - ) - dropout_agg: float = field( - default=0.0, metadata={"help": "dropout to apply after aggregation step"} - ) - aggregator: AGGREGATOR_CHOICES = field( - default="cnn", metadata={"help": "type of aggregator to use"} - ) - gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"}) - no_conv_bias: bool = field( - default=False, metadata={"help": "if set, does not learn bias for conv layers"} - ) - agg_zero_pad: bool = field( - default=False, - metadata={"help": "if set, zero pads in aggregator instead of repl pad"}, - ) - skip_connections_feat: bool = field( - default=False, - metadata={"help": "if set, adds skip connections to the feature extractor"}, - ) - skip_connections_agg: bool = field( - default=True, - metadata={"help": "if set, adds skip connections to the aggregator"}, - ) - residual_scale: float = field( - default=0.5, metadata={"help": "scales residual by sqrt(value)"} - ) - log_compression: bool = field( - default=True, - metadata={"help": "if set, adds a log compression to feature extractor"}, - ) - balanced_classes: bool = field( - default=False, - metadata={"help": "if set, loss is scaled to balance for number of negatives"}, - ) - project_features: PROJECT_FEATURES_CHOICES = field( - default="none", - metadata={ - "help": "if not none, features are projected using the (same or new) aggregator" - }, - ) - non_affine_group_norm: bool = field( - default=False, metadata={"help": "if set, group norm is not affine"} - ) - offset: str = field( - default="auto", - metadata={ - "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" - }, - ) - activation: ACTIVATION_CHOICES = field( - default="relu", - metadata={ - "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" - }, - ) - vq_type: VQ_TYPE_CHOICES = field( - default="none", metadata={"help": "which type of quantizer to use"} - ) - vq_vars: int = field( - default=320, - metadata={"help": "project to this many vector quantized variables per group"}, - ) - vq_groups: int = field( - default=2, metadata={"help": "number of groups of latent variables"} - ) - vq_dim: int = field( - default=0, - metadata={ - "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups" - }, - ) - vq_depth: int = field( - default=1, metadata={"help": "number of layers for vq weight projection"} - ) - combine_groups: bool = field( - default=False, metadata={"help": "if set, variables are shared among groups"} - ) - vq_temp: Tuple[float, float, float] = field( - default=(2.0, 0.5, 0.999995), - metadata={ - "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)" - }, - ) - vq_gamma: float = field( - default=0.25, - metadata={"help": "gamma parameter for kmeans style vector quantization"}, - ) - infonce: bool = II("criterion.infonce") - - -@register_model("wav2vec", dataclass=Wav2VecConfig) -class Wav2VecModel(BaseFairseqModel): - @classmethod - def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask): - """Build a new model instance.""" - - model = Wav2VecModel(cfg) - logger.info(model) - return model - - def __init__(self, cfg: Wav2VecConfig): - super().__init__() - - self.prediction_steps = cfg.prediction_steps - offset = cfg.offset - - if cfg.activation == "relu": - activation = nn.ReLU() - elif cfg.activation == "gelu": - activation = nn.GELU() - else: - raise Exception("unknown activation " + cfg.activation) - - feature_enc_layers = eval(cfg.conv_feature_layers) - self.feature_extractor = ConvFeatureExtractionModel( - conv_layers=feature_enc_layers, - dropout=0.0, - log_compression=cfg.log_compression, - skip_connections=cfg.skip_connections_feat, - residual_scale=cfg.residual_scale, - non_affine_group_norm=cfg.non_affine_group_norm, - activation=activation, - ) - embed = feature_enc_layers[-1][0] - - self.vector_quantizer = None - if cfg.vq_type == "gumbel": - self.vector_quantizer = GumbelVectorQuantizer( - dim=embed, - num_vars=cfg.vq_vars, - temp=cfg.vq_temp, - groups=cfg.vq_groups, - combine_groups=cfg.combine_groups, - vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, - time_first=False, - activation=activation, - weight_proj_depth=cfg.vq_depth, - weight_proj_factor=2, - ) - elif cfg.vq_type == "kmeans": - self.vector_quantizer = KmeansVectorQuantizer( - dim=embed, - num_vars=cfg.vq_vars, - groups=cfg.vq_groups, - combine_groups=cfg.combine_groups, - vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, - time_first=False, - gamma=cfg.vq_gamma, - ) - else: - assert ( - cfg.vq_type == "none" or cfg.vq_type is None - ), "Unknown quantizer type" - - if cfg.offset == "auto": - jin = 0 - rin = 0 - for _, k, stride in feature_enc_layers: - if rin == 0: - rin = k - rin = rin + (k - 1) * jin - if jin == 0: - jin = stride - else: - jin *= stride - offset = math.ceil(rin / jin) - - offset = int(offset) - - def make_aggregator(): - if cfg.aggregator == "cnn": - agg_layers = eval(cfg.conv_aggregator_layers) - agg_dim = agg_layers[-1][0] - feature_aggregator = ConvAggegator( - conv_layers=agg_layers, - embed=embed, - dropout=cfg.dropout, - skip_connections=cfg.skip_connections_agg, - residual_scale=cfg.residual_scale, - non_affine_group_norm=cfg.non_affine_group_norm, - conv_bias=not cfg.no_conv_bias, - zero_pad=cfg.agg_zero_pad, - activation=activation, - ) - elif cfg.aggregator == "gru": - agg_dim = cfg.gru_dim - feature_aggregator = nn.Sequential( - TransposeLast(), - nn.GRU( - input_size=embed, - hidden_size=agg_dim, - num_layers=1, - dropout=cfg.dropout, - ), - TransposeLast(deconstruct_idx=0), - ) - else: - raise Exception("unknown aggregator type " + cfg.aggregator) - - return feature_aggregator, agg_dim - - self.feature_aggregator, agg_dim = make_aggregator() - - self.wav2vec_predictions = Wav2VecPredictionsModel( - in_dim=agg_dim, - out_dim=embed, - prediction_steps=cfg.prediction_steps, - n_negatives=cfg.num_negatives, - cross_sample_negatives=cfg.cross_sample_negatives, - sample_distance=cfg.sample_distance, - dropout=cfg.dropout, - offset=offset, - balanced_classes=cfg.balanced_classes, - infonce=cfg.infonce, - ) - - self.dropout_feats = nn.Dropout(p=cfg.dropout_features) - self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) - - if cfg.project_features == "none": - self.project_features = None - elif cfg.project_features == "same": - self.project_features = self.feature_aggregator - elif cfg.project_features == "new": - self.project_features, _ = make_aggregator() - - def forward(self, source): - result = {} - - features = self.feature_extractor(source) - if self.vector_quantizer: - q_res = self.vector_quantizer(features) - features = q_res["x"] - for k in q_res.keys(): - if k != "x": - result[k] = q_res[k] - - x = self.dropout_feats(features) - x = self.feature_aggregator(x) - x = self.dropout_agg(x) - - if self.project_features is not None: - features = self.project_features(features) - x, targets = self.wav2vec_predictions(x, features) - result["cpc_logits"] = x - result["cpc_targets"] = targets - - return result - - def upgrade_state_dict_named(self, state_dict, name): - super().upgrade_state_dict_named(state_dict, name) - - def max_positions(self): - """Maximum length supported by the model.""" - return sys.maxsize - - def get_logits(self, net_output): - logits = net_output["cpc_logits"] - return logits - - def get_targets(self, sample, net_output): - t = net_output["cpc_targets"] - if isinstance(t, tuple): - t = t[0] - return t.contiguous() - - def get_target_weights(self, targets, net_output): - targets = net_output["cpc_targets"] - if isinstance(targets, tuple) and targets[-1] is not None: - return targets[-1] - return None - - def get_extra_losses(self, net_output): - loss = None - if "prob_perplexity" in net_output: - loss = net_output["num_vars"] - net_output["prob_perplexity"] - elif "kmeans_loss" in net_output: - loss = net_output["kmeans_loss"] - - return loss - - -def norm_block(is_layer_norm, dim, affine=True): - if is_layer_norm: - mod = nn.Sequential( - TransposeLast(), - Fp32LayerNorm(dim, elementwise_affine=affine), - TransposeLast(), - ) - else: - mod = Fp32GroupNorm(1, dim, affine=affine) - - return mod - - -class ConvFeatureExtractionModel(nn.Module): - def __init__( - self, - conv_layers, - dropout, - log_compression, - skip_connections, - residual_scale, - non_affine_group_norm, - activation, - ): - super().__init__() - - def block(n_in, n_out, k, stride): - return nn.Sequential( - nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), - nn.Dropout(p=dropout), - norm_block( - is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm - ), - activation, - ) - - in_d = 1 - self.conv_layers = nn.ModuleList() - for dim, k, stride in conv_layers: - self.conv_layers.append(block(in_d, dim, k, stride)) - in_d = dim - - self.log_compression = log_compression - self.skip_connections = skip_connections - self.residual_scale = math.sqrt(residual_scale) - - def forward(self, x): - # BxT -> BxCxT - x = x.unsqueeze(1) - - for conv in self.conv_layers: - residual = x - x = conv(x) - if self.skip_connections and x.size(1) == residual.size(1): - tsz = x.size(2) - r_tsz = residual.size(2) - residual = residual[..., :: r_tsz // tsz][..., :tsz] - x = (x + residual) * self.residual_scale - - if self.log_compression: - x = x.abs() - x = x + 1 - x = x.log() - - return x - - -class ZeroPad1d(nn.Module): - def __init__(self, pad_left, pad_right): - super().__init__() - self.pad_left = pad_left - self.pad_right = pad_right - - def forward(self, x): - return F.pad(x, (self.pad_left, self.pad_right)) - - -class ConvAggegator(nn.Module): - def __init__( - self, - conv_layers, - embed, - dropout, - skip_connections, - residual_scale, - non_affine_group_norm, - conv_bias, - zero_pad, - activation, - ): - super().__init__() - - def block(n_in, n_out, k, stride): - # padding dims only really make sense for stride = 1 - ka = k // 2 - kb = ka - 1 if k % 2 == 0 else ka - - pad = ( - ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0)) - ) - - return nn.Sequential( - pad, - nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), - nn.Dropout(p=dropout), - norm_block(False, n_out, affine=not non_affine_group_norm), - activation, - ) - - in_d = embed - self.conv_layers = nn.ModuleList() - self.residual_proj = nn.ModuleList() - for dim, k, stride in conv_layers: - if in_d != dim and skip_connections: - self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) - else: - self.residual_proj.append(None) - - self.conv_layers.append(block(in_d, dim, k, stride)) - in_d = dim - self.conv_layers = nn.Sequential(*self.conv_layers) - self.skip_connections = skip_connections - self.residual_scale = math.sqrt(residual_scale) - - def forward(self, x): - for rproj, conv in zip(self.residual_proj, self.conv_layers): - residual = x - x = conv(x) - if self.skip_connections: - if rproj is not None: - residual = rproj(residual) - x = (x + residual) * self.residual_scale - return x - - -class Wav2VecPredictionsModel(nn.Module): - def __init__( - self, - in_dim, - out_dim, - prediction_steps, - n_negatives, - cross_sample_negatives, - sample_distance, - dropout, - offset, - balanced_classes, - infonce, - ): - super().__init__() - - self.n_negatives = n_negatives - self.cross_sample_negatives = cross_sample_negatives - self.sample_distance = sample_distance - self.project_to_steps = nn.ConvTranspose2d( - in_dim, out_dim, (1, prediction_steps) - ) - self.dropout = nn.Dropout(p=dropout) - self.offset = offset - self.balanced_classes = balanced_classes - self.infonce = infonce - - def sample_negatives(self, y): - bsz, fsz, tsz = y.shape - - y = y.transpose(0, 1) # BCT -> CBT - y = y.contiguous().view(fsz, -1) # CBT => C(BxT) - - cross_high = tsz * bsz - high = tsz if self.sample_distance is None else min(tsz, self.sample_distance) - assert high > 1 - - neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz)) - - with torch.no_grad(): - if self.n_negatives > 0: - tszs = ( - buffered_arange(tsz) - .unsqueeze(-1) - .expand(-1, self.n_negatives) - .flatten() - ) - - neg_idxs = torch.randint( - low=0, high=high - 1, size=(bsz, self.n_negatives * tsz) - ) - neg_idxs[neg_idxs >= tszs] += 1 - - if self.cross_sample_negatives > 0: - tszs = ( - buffered_arange(tsz) - .unsqueeze(-1) - .expand(-1, self.cross_sample_negatives) - .flatten() - ) - - cross_neg_idxs = torch.randint( - low=0, - high=cross_high - 1, - size=(bsz, self.cross_sample_negatives * tsz), - ) - cross_neg_idxs[cross_neg_idxs >= tszs] += 1 - - if self.n_negatives > 0: - for i in range(1, bsz): - neg_idxs[i] += i * high - else: - neg_idxs = cross_neg_idxs - - if self.cross_sample_negatives > 0 and self.n_negatives > 0: - neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) - - negs = y[..., neg_idxs.view(-1)] - negs = negs.view( - fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz - ).permute( - 2, 1, 0, 3 - ) # to NxBxCxT - - return negs - - def forward(self, x, y): - - x = x.unsqueeze(-1) - x = self.project_to_steps(x) # BxCxTxS - x = self.dropout(x) - - negatives = self.sample_negatives(y) - y = y.unsqueeze(0) - targets = torch.cat([y, negatives], dim=0) # Copies x B x C x T - - copies = targets.size(0) - bsz, dim, tsz, steps = x.shape - steps = min(steps, tsz - self.offset) - - predictions = x.new( - bsz * copies * (tsz - self.offset + 1) * steps - - ((steps + 1) * steps // 2) * copies * bsz - ) - if self.infonce: - labels = predictions.new_full( - (predictions.shape[0] // copies,), 0, dtype=torch.long - ) - else: - labels = torch.zeros_like(predictions) - weights = ( - torch.full_like(labels, 1 / self.n_negatives) - if self.balanced_classes and not self.infonce - else None - ) - - start = end = 0 - for i in range(steps): - offset = i + self.offset - end = start + (tsz - offset) * bsz * copies - if self.infonce: - predictions[start:end] = torch.einsum( - "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:] - ).flatten() - else: - pos_num = (end - start) // copies - predictions[start:end] = torch.einsum( - "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:] - ).flatten() - labels[start : start + pos_num] = 1.0 - if weights is not None: - weights[start : start + pos_num] = 1.0 - start = end - assert end == predictions.numel(), "{} != {}".format(end, predictions.numel()) - - if self.infonce: - predictions = predictions.view(-1, copies) - else: - if weights is not None: - labels = (labels, weights) - - return predictions, labels