# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import numbers
import random
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from torch import Tensor
from torch.nn import Linear, Module, Parameter
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torchmetrics.classification import MulticlassAccuracy

# from valle.data.input_strategies import PromptedFeatures
# from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
# from valle.modules.transformer import (
#     AdaptiveLayerNorm,
#     LayerNorm,
#     TransformerEncoder,
#     TransformerEncoderLayer,
# )

from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
from .visualizer import visualize

# NOTE: `IdentityNorm`, `BalancedBasicNorm` and `BalancedDoubleSwish` are
# referenced by TransformerEncoderLayer below but are not defined in this
# file; they are assumed to be provided by the accompanying scaling module.


class PromptedFeatures:
    """Bundles a (prompts, features) pair so they move between devices together."""

    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        return PromptedFeatures(
            self.prompts.to(device), self.features.to(device)
        )

    def sum(self):
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        x = self.word_embeddings(x)
        x = self.dropout(x)
        return x


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)


class Transpose(nn.Identity):
    """(N, T, D)
    -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


_shape_t = Union[int, List[int], torch.Size]


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are
      the same tensor. This restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``)
      or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed,
      neither ``key_padding_mask`` nor ``attn_mask`` is passed

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the
    input that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be
            split across ``num_heads`` (i.e. each head will have dimension
            ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0``
            (no dropout).
        bias: If specified, adds bias to input / output projection layers.
            Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0.
            Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value
            sequences at dim=1. Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples:: >>> # xdoctest: +SKIP >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value) """ __constants__ = ["batch_first"] bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] def __init__( self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, linear1_cls=Linear, linear2_cls=Linear, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self._qkv_same_embed_dim = ( self.kdim == embed_dim and self.vdim == embed_dim ) self.num_heads = num_heads self.dropout = dropout self.batch_first = batch_first self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" if add_bias_kv: self.bias_k = Parameter( torch.empty((1, 1, embed_dim), **factory_kwargs) ) self.bias_v = Parameter( torch.empty((1, 1, embed_dim), **factory_kwargs) ) else: self.bias_k = self.bias_v = None if linear1_cls == Linear: if not self._qkv_same_embed_dim: self.q_proj_weight = Parameter( torch.empty((embed_dim, embed_dim), **factory_kwargs) ) self.k_proj_weight = Parameter( torch.empty((embed_dim, self.kdim), **factory_kwargs) ) self.v_proj_weight = Parameter( torch.empty((embed_dim, self.vdim), **factory_kwargs) ) self.register_parameter("in_proj_weight", None) else: self.in_proj_weight = Parameter( torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) ) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) if bias: self.in_proj_bias = Parameter( torch.empty(3 * embed_dim, **factory_kwargs) ) else: self.register_parameter("in_proj_bias", None) self.out_proj = NonDynamicallyQuantizableLinear( embed_dim, embed_dim, bias=bias, **factory_kwargs ) self._reset_parameters() else: if not self._qkv_same_embed_dim: raise NotImplementedError else: self.in_proj_linear = linear1_cls( embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs ) self.in_proj_weight = self.in_proj_linear.weight self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) if bias: self.in_proj_bias = self.in_proj_linear.bias else: self.register_parameter("in_proj_bias", None) self.out_proj = linear2_cls( embed_dim, embed_dim, bias=bias, **factory_kwargs ) if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: xavier_normal_(self.bias_v) self.add_zero_attn = add_zero_attn def _reset_parameters(self): if self._qkv_same_embed_dim: xavier_uniform_(self.in_proj_weight) else: xavier_uniform_(self.q_proj_weight) xavier_uniform_(self.k_proj_weight) xavier_uniform_(self.v_proj_weight) if self.in_proj_bias is not None: constant_(self.in_proj_bias, 0.0) constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: xavier_normal_(self.bias_v) def __setstate__(self, state): # Support loading old MultiheadAttention checkpoints generated by v1.1.0 if "_qkv_same_embed_dim" not in state: state["_qkv_same_embed_dim"] = True super(MultiheadAttention, self).__setstate__(state) def forward( self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, 
need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details. key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and byte masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. Default: ``True``. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) Outputs: - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the embedding dimension ``embed_dim``. - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. 
If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. .. note:: `batch_first` argument is ignored for unbatched inputs. """ is_batched = query.dim() == 3 if key_padding_mask is not None: _kpm_dtype = key_padding_mask.dtype if _kpm_dtype != torch.bool and not torch.is_floating_point( key_padding_mask ): raise AssertionError( "only bool and floating types of key_padding_mask are supported" ) why_not_fast_path = "" if not is_batched: why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" elif query is not key or key is not value: # When lifting this restriction, don't forget to either # enforce that the dtypes all match or test cases where # they don't! why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" elif ( self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype ): why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" elif ( self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype ): # this case will fail anyway, but at least they'll get a useful error message. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" elif self.training: why_not_fast_path = "training is enabled" elif not self.batch_first: why_not_fast_path = "batch_first was not True" elif self.bias_k is not None: why_not_fast_path = "self.bias_k was not None" elif self.bias_v is not None: why_not_fast_path = "self.bias_v was not None" elif self.dropout: why_not_fast_path = f"dropout was {self.dropout}, required zero" elif self.add_zero_attn: why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: why_not_fast_path = "_qkv_same_embed_dim was not True" elif attn_mask is not None: why_not_fast_path = "attn_mask was not None" elif query.is_nested and key_padding_mask is not None: why_not_fast_path = ( "key_padding_mask is not supported with NestedTensor input" ) elif self.num_heads % 2 == 1: why_not_fast_path = "num_heads is odd" elif torch.is_autocast_enabled(): why_not_fast_path = "autocast is enabled" if not why_not_fast_path: tensor_args = ( query, key, value, self.in_proj_weight, self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, ) # We have to use list comprehensions below because TorchScript does not support # generator expressions. 
if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" elif not all( [ (x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args ] ): why_not_fast_path = ( "some Tensor argument is neither CUDA nor CPU" ) elif torch.is_grad_enabled() and any( [x is not None and x.requires_grad for x in tensor_args] ): why_not_fast_path = ( "grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad" ) if not why_not_fast_path: return torch._native_multi_head_attention( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, key_padding_mask if key_padding_mask is not None else attn_mask, need_weights, average_attn_weights, 1 if key_padding_mask is not None else 0 if attn_mask is not None else None, ) any_nested = query.is_nested or key.is_nested or value.is_nested assert not any_nested, ( "MultiheadAttention does not support NestedTensor outside of its fast path. " + f"The fast path was not hit because {why_not_fast_path}" ) if self.batch_first and is_batched: # make sure that the transpose op does not affect the "is" property if key is value: if query is key: query = key = value = query.transpose(1, 0) else: query, key = [x.transpose(1, 0) for x in (query, key)] value = key else: query, key, value = [ x.transpose(1, 0) for x in (query, key, value) ] if not self._qkv_same_embed_dim: attn_output, attn_output_weights = F.multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights, ) else: attn_output, attn_output_weights = F.multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, average_attn_weights=average_attn_weights, ) if self.batch_first and is_batched: return attn_output.transpose(1, 0), attn_output_weights else: return attn_output, attn_output_weights class LayerNorm(nn.Module): __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: Tuple[int, ...] 
eps: float elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(LayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): # mypy error: incompatible types in assignment normalized_shape = (normalized_shape,) # type: ignore[assignment] self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = nn.Parameter( torch.empty(self.normalized_shape, **factory_kwargs) ) self.bias = nn.Parameter( torch.empty(self.normalized_shape, **factory_kwargs) ) else: self.register_parameter("weight", None) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: if self.elementwise_affine: nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, input: Tensor, embedding: Any = None) -> Tensor: if isinstance(input, tuple): input, embedding = input return ( F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps, ), embedding, ) assert embedding is None return F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps ) def extra_repr(self) -> str: return ( "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) ) class AdaptiveLayerNorm(nn.Module): r"""Adaptive Layer Normalization""" def __init__(self, d_model, norm) -> None: super(AdaptiveLayerNorm, self).__init__() self.project_layer = nn.Linear(d_model, 2 * d_model) self.norm = norm self.d_model = d_model self.eps = self.norm.eps def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: if isinstance(input, tuple): input, embedding = input weight, bias = torch.split( self.project_layer(embedding), split_size_or_sections=self.d_model, dim=-1, ) return (weight * self.norm(input) + bias, embedding) weight, bias = torch.split( self.project_layer(embedding), split_size_or_sections=self.d_model, dim=-1, ) return weight * self.norm(input) + bias class TransformerEncoderLayer(nn.Module): __constants__ = ["batch_first", "norm_first"] def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None, linear1_self_attention_cls: nn.Module = nn.Linear, linear2_self_attention_cls: nn.Module = nn.Linear, linear1_feedforward_cls: nn.Module = nn.Linear, linear2_feedforward_cls: nn.Module = nn.Linear, layer_norm_cls: nn.Module = LayerNorm, layer_norm_eps: float = 1e-5, adaptive_layer_norm=False, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, linear1_cls=linear1_self_attention_cls, linear2_cls=linear2_self_attention_cls, **factory_kwargs, ) # Implementation of Feedforward model self.linear1 = linear1_feedforward_cls( d_model, dim_feedforward, **factory_kwargs ) self.dropout = nn.Dropout(dropout) self.linear2 = linear2_feedforward_cls( dim_feedforward, d_model, **factory_kwargs ) self.norm_first = norm_first self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) # Legacy string support for activation function. 
if isinstance(activation, str): activation = _get_activation_fn(activation) elif isinstance(activation, partial): activation = activation(d_model) # elif activation == BalancedDoubleSwish: # activation = BalancedDoubleSwish(d_model) # # We can't test self.activation in forward() in TorchScript, # # so stash some information about it instead. # if activation is F.relu or isinstance(activation, torch.nn.ReLU): # self.activation_relu_or_gelu = 1 # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): # self.activation_relu_or_gelu = 2 # else: # self.activation_relu_or_gelu = 0 self.activation = activation norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if layer_norm_cls == IdentityNorm: norm2 = BalancedBasicNorm( d_model, eps=layer_norm_eps, **factory_kwargs ) else: norm2 = layer_norm_cls( d_model, eps=layer_norm_eps, **factory_kwargs ) if adaptive_layer_norm: self.norm1 = AdaptiveLayerNorm(d_model, norm1) self.norm2 = AdaptiveLayerNorm(d_model, norm2) else: self.norm1 = norm1 self.norm2 = norm2 def __setstate__(self, state): super(TransformerEncoderLayer, self).__setstate__(state) if not hasattr(self, "activation"): self.activation = F.relu def forward( self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). Shape: see the docs in Transformer class. """ x, stage_embedding = src, None is_src_tuple = False if isinstance(src, tuple): x, stage_embedding = src is_src_tuple = True if src_key_padding_mask is not None: _skpm_dtype = src_key_padding_mask.dtype if _skpm_dtype != torch.bool and not torch.is_floating_point( src_key_padding_mask ): raise AssertionError( "only bool and floating types of key_padding_mask are supported" ) if self.norm_first: x = x + self._sa_block( self.norm1(x, stage_embedding), src_mask, src_key_padding_mask, ) x = x + self._ff_block(self.norm2(x, stage_embedding)) else: x = self.norm1( x + self._sa_block(x, src_mask, src_key_padding_mask), stage_embedding, ) x = self.norm2(x + self._ff_block(x), stage_embedding) if is_src_tuple: return (x, stage_embedding) return x # self-attention block def _sa_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], ) -> Tensor: x = self.self_attn( x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0] return self.dropout1(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) class TransformerEncoder(nn.Module): r"""TransformerEncoder is a stack of N encoder layers. Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. Args: encoder_layer: an instance of the TransformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). norm: the layer normalization component (optional). enable_nested_tensor: if True, input will automatically convert to nested tensor (and convert back on output). This will improve the overall performance of TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 
Examples:: >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src) """ __constants__ = ["norm"] def __init__(self, encoder_layer, num_layers, norm=None): super(TransformerEncoder, self).__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward( self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, return_layer_states: bool = False, ) -> Tensor: r"""Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). return_layer_states: return layers' state (optional). Shape: see the docs in Transformer class. """ if return_layer_states: layer_states = [] # layers' output output = src for mod in self.layers: output = mod( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) layer_states.append(output[0]) if self.norm is not None: output = self.norm(output) return layer_states, output output = src for mod in self.layers: output = mod( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask ) if self.norm is not None: output = self.norm(output) return output def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError( "activation should be relu/gelu, not {}".format(activation) ) class VALLE(nn.Module): """It implements https://arxiv.org/abs/2301.02111 "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" """ def __init__( self, d_model: int, nhead: int, num_layers: int, norm_first: bool = True, add_prenet: bool = False, decoder_cls=TransformerEncoder, decoder_layer_cls=TransformerEncoderLayer, prefix_mode: int = 0, share_embedding: bool = True, nar_scale_factor: float = 1.0, prepend_bos: bool = False, num_quantizers: int = 8, **kwargs, ): """ Args: d_model: The number of expected features in the input (required). nhead: The number of heads in the multiheadattention models (required). num_layers: The number of sub-decoder-layers in the decoder (required). 
""" super().__init__() nar_d_model = int(d_model * nar_scale_factor) self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS) # ID NUM_AUDIO_TOKENS -> PAD # ID NUM_AUDIO_TOKENS + 1 -> BOS self.ar_audio_prepend_bos = prepend_bos self.ar_audio_embedding = TokenEmbedding( d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos) ) # PreNet if add_prenet: self.ar_text_prenet = nn.Sequential( Transpose(), nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(d_model), nn.ReLU(), nn.Dropout(0.5), Transpose(), nn.Linear(d_model, d_model), ) self.ar_audio_prenet = nn.Sequential( nn.Linear(d_model, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, d_model), ) else: self.ar_text_prenet = nn.Identity() self.ar_audio_prenet = nn.Identity() self.ar_text_position = SinePositionalEmbedding( d_model, dropout=0.1, scale=False, alpha=True, ) self.ar_audio_position = SinePositionalEmbedding( d_model, dropout=0.1, scale=False, alpha=True, ) self.ar_decoder = decoder_cls( decoder_layer_cls( d_model, nhead, dim_feedforward=d_model * 4, dropout=0.1, batch_first=True, norm_first=norm_first, ), num_layers=num_layers, norm=LayerNorm(d_model) if norm_first else None, ) self.ar_predict_layer = nn.Linear( d_model, NUM_AUDIO_TOKENS + 1, bias=False ) self.ar_accuracy_metric = MulticlassAccuracy( NUM_AUDIO_TOKENS + 1, top_k=10, average="micro", multidim_average="global", ignore_index=NUM_AUDIO_TOKENS, ) self.rng = random.Random(0) self.num_heads = nhead self.prefix_mode = prefix_mode self.num_quantizers = num_quantizers assert num_quantizers >= 1 if num_quantizers > 1: self.nar_audio_embeddings = nn.ModuleList( [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)] + [ TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS) for i in range(num_quantizers - 1) ] ) # W_a # PreNet if add_prenet: self.nar_text_prenet = nn.Sequential( Transpose(), nn.Conv1d( nar_d_model, nar_d_model, kernel_size=5, padding="same" ), nn.BatchNorm1d(nar_d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d( nar_d_model, nar_d_model, kernel_size=5, padding="same" ), nn.BatchNorm1d(nar_d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d( nar_d_model, nar_d_model, kernel_size=5, padding="same" ), nn.BatchNorm1d(nar_d_model), nn.ReLU(), nn.Dropout(0.5), Transpose(), nn.Linear(nar_d_model, nar_d_model), ) self.nar_audio_prenet = nn.Sequential( nn.Linear(nar_d_model, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, nar_d_model), ) else: self.nar_text_prenet = nn.Identity() self.nar_audio_prenet = nn.Identity() self.nar_text_position = SinePositionalEmbedding( nar_d_model, dropout=0.0, scale=False, alpha=False, ) self.nar_audio_position = SinePositionalEmbedding( nar_d_model, dropout=0.1, scale=False, alpha=False, ) self.nar_decoder = decoder_cls( decoder_layer_cls( nar_d_model, int(nhead * nar_scale_factor), dim_feedforward=nar_d_model * 4, dropout=0.1, batch_first=True, norm_first=norm_first, adaptive_layer_norm=True, ), num_layers=int(num_layers * nar_scale_factor), norm=AdaptiveLayerNorm( nar_d_model, norm=nn.LayerNorm(nar_d_model) ) if norm_first else None, ) self.nar_predict_layers = nn.ModuleList( [ nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False) 
for i in range(num_quantizers - 1) ] ) self.nar_stage_embeddings = nn.ModuleList( [ TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1) ] ) if share_embedding: # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa # NOTE(Feiteng): In the experiment, this undermines accuracy # self.ar_predict_layer.weight = self.ar_audio_embedding.weight # We also share the parameters of the acoustic embedding layer and the output prediction layer, # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer. for j in range(0, num_quantizers - 2): self.nar_predict_layers[ j ].weight = self.nar_audio_embeddings[j + 2].weight self.nar_accuracy_metric = MulticlassAccuracy( NUM_AUDIO_TOKENS + 1, top_k=10, average="micro", multidim_average="global", ignore_index=NUM_AUDIO_TOKENS, ) def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: assert stage > 0 if stage == 1: for name, param in self.named_parameters(): if name.startswith("ar_"): print(f" AR parameter: {name}") yield param if stage == 2: for name, param in self.named_parameters(): if name.startswith("nar_"): print(f"NAR parameter: {name}") yield param def stage_named_parameters( self, stage: int = 1 ) -> Iterator[Tuple[str, nn.Parameter]]: assert stage > 0 if stage == 1: for pair in self.named_parameters(): if pair[0].startswith("ar_"): yield pair if stage == 2: for pair in self.named_parameters(): if pair[0].startswith("nar_"): yield pair def pad_y_eos(self, y, y_mask_int, eos_id): targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( y_mask_int, (0, 1), value=1 ) # inputs, targets if self.ar_audio_prepend_bos: return ( F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1), targets, ) return targets[:, :-1], targets[:, 1:] def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds # from the same utterance. # We implement this differently. 
if self.prefix_mode == 0: # no prefix prefix_len = 0 y_emb = self.nar_audio_embeddings[0](y) for j in range(1, nar_stage): # Formula (4) (5) y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) elif self.prefix_mode == 1: # prefix at begining int_low = (0.25 * y_lens.min()).type(torch.int64).item() prefix_len = torch.randint(int_low, int_low * 2, size=()).item() prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) for j in range(1, self.num_quantizers): y_prompts += self.nar_audio_embeddings[j]( codes[:, :prefix_len, j] ) if j < nar_stage: y_emb += self.nar_audio_embeddings[j]( codes[:, prefix_len:, j] ) y_emb = torch.concat([y_prompts, y_emb], axis=1) elif self.prefix_mode in [2, 4]: if self.prefix_mode == 2: # random prefix prefix_len = min(225, int(0.25 * y_lens.min().item())) y_prompts_codes = [] for b in range(codes.shape[0]): start = self.rng.randint(0, y_lens[b].item() - prefix_len) y_prompts_codes.append( torch.clone(codes[b, start : start + prefix_len]) ) codes[ b, start : start + prefix_len, nar_stage ] = NUM_AUDIO_TOKENS y_prompts_codes = torch.stack(y_prompts_codes, dim=0) else: prefix_len = y_prompts_codes.shape[1] y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) y_emb = self.nar_audio_embeddings[0](y) for j in range(1, self.num_quantizers): y_prompts += self.nar_audio_embeddings[j]( y_prompts_codes[..., j] ) if j < nar_stage: y_emb += self.nar_audio_embeddings[j](codes[..., j]) y_emb = torch.concat([y_prompts, y_emb], axis=1) else: raise ValueError return y_emb, prefix_len def visualize( self, predicts: Tuple[torch.Tensor], batch: Dict[str, Union[List, torch.Tensor]], output_dir: str, limit: int = 4, ) -> None: visualize(predicts, batch, output_dir, limit=limit) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, y: Union[torch.Tensor, PromptedFeatures], y_lens: Union[torch.Tensor, PromptedFeatures], reduction: str = "sum", train_stage: int = 0, **kwargs, ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: """ Args: x: A 2-D tensor of shape (N, S). x_lens: A 1-D tensor of shape (N,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (N, T, 8). y_lens: A 1-D tensor of shape (N,). It contains the number of tokens in `x` before padding. train_stage: 0: AR & NAR modules, 1: AR modules, 2: NAR modules Returns: Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. 
""" assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape y_prompts_codes = None if isinstance(y, PromptedFeatures): y_prompts_codes, y = y.data prompts_len, y_lens = y_lens.data assert prompts_len.min() == prompts_len.max() assert self.prefix_mode == 4 y_prompts_codes = y_prompts_codes.type(torch.int64) assert y.ndim == 3, y.shape assert y_lens.ndim == 1, y_lens.shape # NOTE: x has been padded in TextTokenCollater x_mask = make_pad_mask(x_lens).to(x.device) y_mask = make_pad_mask(y_lens).to(y.device) y_mask_int = y_mask.type(torch.int64) text = x codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) y, targets = self.pad_y_eos( codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS ) x_len = x_lens.max() metrics = {} total_loss = 0.0 xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) if self.ar_audio_prepend_bos: ar_xy_padding_mask = torch.concat( [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 ) else: ar_xy_padding_mask = xy_padding_mask # AR Decoder if train_stage in [0, 1]: x = self.ar_text_embedding(text) x = self.ar_text_prenet(x) x = self.ar_text_position(x) y_len = y_lens.max() + int(self.ar_audio_prepend_bos) x_attn_mask = F.pad( torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), (0, y_len), value=True, ) y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1, ), (x_len, 0), value=False, ) xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) # merge key padding and attention masks bsz, src_len = x.shape[0], x_len + y_len _xy_padding_mask = ( ar_xy_padding_mask.view(bsz, 1, 1, src_len) .expand(-1, self.num_heads, -1, -1) .reshape(bsz * self.num_heads, 1, src_len) ) xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) xy_attn_mask = new_attn_mask y_emb = self.ar_audio_embedding(y) y_emb = self.ar_audio_prenet(y_emb) y_pos = self.ar_audio_position(y_emb) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.ar_decoder( (xy_pos, None), mask=xy_attn_mask, # src_key_padding_mask=xy_padding_mask, # is_causal=True, ) logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) # loss total_loss = F.cross_entropy(logits, targets, reduction=reduction) metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( logits.detach(), targets ).item() * y_lens.sum().type(torch.float32) if self.num_quantizers == 1: return ((x, codes), total_loss, metrics) # Non-AR Decoders if self.ar_audio_prepend_bos: y = y[:, 1:] if train_stage in [0, 2]: num_nar_layers = self.num_quantizers - 1 nar_stage = self.rng.choices( [_k for _k in range(1, self.num_quantizers)], weights=[1.0 / num_nar_layers] * num_nar_layers, k=1, )[0] x = self.nar_text_embedding(text) x = self.nar_text_prenet(x) x = self.nar_text_position(x) y_emb, prefix_len = self._prepare_prompts( y, y_lens, codes, nar_stage, y_prompts_codes ) y_len = y_lens.max() targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int if self.prefix_mode in [2, 4]: xy_padding_mask = torch.concat( [ x_mask, F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), ], dim=1, ) elif self.prefix_mode == 1: targets = targets[:, prefix_len:] y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), src_key_padding_mask=xy_padding_mask, # is_causal=False, ) xy_dec = xy_dec[:, x_lens.max() + 
prefix_len :] if self.prefix_mode == 4: prefix_len = 0 # reset for Top10Accuracy metric logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute( 0, 2, 1 ) # loss total_length = (y_lens).sum().type(torch.float32) total_loss += ( F.cross_entropy( logits, targets, ignore_index=NUM_AUDIO_TOKENS, reduction=reduction, ) * (total_length / (total_length - prefix_len * x.shape[0])) ) metrics["NarTop10Accuracy"] = ( self.nar_accuracy_metric( F.pad( logits.detach(), (0, 0, 0, 1, 0, 0), value=logits.min().cpu().item(), ), targets, ).item() * total_length ) if train_stage == 0: total_loss = total_loss / 2.0 return ((x, codes), total_loss, metrics) def inference( self, x: torch.Tensor, x_lens: torch.Tensor, y: torch.Tensor, enroll_x_lens: torch.Tensor, top_k: int = -100, temperature: float = 1.0, top_p: float = 1.0, ) -> torch.Tensor: """ Args: x: A 2-D tensor of shape (1, S). x_lens: A 1-D tensor of shape (1,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (1, T, 8). top_k: (`optional`) int The number of highest probability tokens to keep for top-k-filtering. Default to -100. temperature: (`optional`) float The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. Returns: Return the predicted audio code matrix. """ assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3, y.shape assert y.shape[0] == 1, y.shape assert torch.all(x_lens > 0) # NOTE: x has been padded in TextTokenCollater text = x x = self.ar_text_embedding(text) x = self.ar_text_prenet(x) x = self.ar_text_position(x) text_len = x_lens.max() prompts = y prefix_len = y.shape[1] # AR Decoder # TODO: Managing decoder steps avoid repetitive computation y = prompts[..., 0] if self.ar_audio_prepend_bos: y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) x_len = x_lens.max() x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) while True: y_emb = self.ar_audio_embedding(y) y_emb = self.ar_audio_prenet(y_emb) y_pos = self.ar_audio_position(y_emb) xy_pos = torch.concat([x, y_pos], dim=1) y_len = y.shape[1] x_attn_mask_pad = F.pad( x_attn_mask, (0, y_len), value=True, ) y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1 ), (x_len, 0), value=False, ) xy_attn_mask = torch.concat( [x_attn_mask_pad, y_attn_mask], dim=0 ).to(y.device) xy_dec, _ = self.ar_decoder( (xy_pos, None), mask=xy_attn_mask, ) logits = self.ar_predict_layer(xy_dec[:, -1]) ras=True samples = topk_sampling( logits, top_k=top_k, top_p=top_p, temperature=temperature, repetition_aware_sampling=ras, preceding_tokens=y ) if ( torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS or samples[0, 0] == NUM_AUDIO_TOKENS or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 ): if prompts.shape[1] == y.shape[1]: raise SyntaxError( "well trained model shouldn't reach here." 
) print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]") break y = torch.concat([y, samples], dim=1) codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] if self.num_quantizers == 1: return torch.stack(codes, dim=-1) # Non-AR Decoders y_emb = self.nar_audio_embeddings[0]( y[:, int(self.ar_audio_prepend_bos) :] ) if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes enrolled_len = enroll_x_lens.max().item() # SOS + Synthesis Text + EOS text = torch.concat( [ text[:, :1], text[:, enrolled_len - 1 :], ], dim=1, ) text_len = text_len - (enrolled_len - 2) assert text.shape[0] == 1 x = self.nar_text_embedding(text) x = self.nar_text_prenet(x) x = self.nar_text_position(x) if self.prefix_mode == 0: for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < self.num_quantizers - 2: y_emb[:, :prefix_len] += embedding_layer( prompts[..., i + 1] ) y_emb[:, prefix_len:] += embedding_layer(samples) else: for j in range(1, self.num_quantizers): y_emb[:, :prefix_len] += self.nar_audio_embeddings[j]( prompts[..., j] ) for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < self.num_quantizers - 2: y_emb[:, prefix_len:] += embedding_layer(samples) assert len(codes) == self.num_quantizers return torch.stack(codes, dim=-1) def continual( self, x: torch.Tensor, x_lens: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: """ Args: x: A 2-D tensor of shape (1, S). x_lens: A 1-D tensor of shape (1,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (1, T, 8). Returns: Return the predicted audio code matrix. 
""" assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3, y.shape assert y.shape[0] == 1, y.shape assert torch.all(x_lens > 0) assert self.num_quantizers == 8 # NOTE: x has been padded in TextTokenCollater text = x x = self.ar_text_embedding(text) x = self.ar_text_prenet(x) x = self.ar_text_position(x) text_len = x_lens.max() prefix_len = min(int(y.shape[1] * 0.5), 3 * 75) # AR Decoder prompts = y[:, :prefix_len] codes = [y[:, prefix_len:, 0]] # Non-AR Decoders x = self.nar_text_embedding(text) x = self.nar_text_prenet(x) x = self.nar_text_position(x) y_emb = self.nar_audio_embeddings[0](y[..., 0]) if self.prefix_mode == 0: for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_position(y_emb) y_pos = self.nar_audio_prenet(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < 6: y_emb[:, :prefix_len] += embedding_layer( prompts[..., i + 1] ) y_emb[:, prefix_len:] += embedding_layer(samples) else: for j in range(1, 8): y_emb[:, :prefix_len] += self.nar_audio_embeddings[j]( prompts[..., j] ) for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < 6: y_emb[:, prefix_len:] += embedding_layer(samples) assert len(codes) == 8 return torch.stack(codes, dim=-1) # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 ): """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: logits: logits distribution shape (batch size, vocabulary size) if top_k > 0: keep only top k tokens with highest probability (top-k filtering). if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. 
            (http://arxiv.org/abs/1904.09751)
    Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits


def topk_sampling(
    logits,
    top_k=10,
    top_p=1.0,
    temperature=1.0,
    repetition_aware_sampling=False,
    preceding_tokens=None,
):
    # temperature: (`optional`) float
    #     The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
    # top_k: (`optional`) int
    #     The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
    # top_p: (`optional`) float
    #     The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(
        logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )
    # Sample
    probs = F.softmax(logits, dim=-1)
    # print top 10 value and index
    print("top 10 value and index", torch.topk(probs, 10), top_p)
    tokens = torch.multinomial(probs, num_samples=1)

    if repetition_aware_sampling:
        window_size = 10
        threshold = 0.1
        # We first generate the target code c_t' by nucleus sampling with a
        # pre-defined top-p value v. Then, we calculate the repetition ratio r
        # of token c_t' in the preceding code sequence with a window size K.
        # If the ratio r exceeds a pre-defined repetition threshold ratio t_n,
        # we replace the target code c_t' by random sampling from p(c_t' | x, c_<t').
        if preceding_tokens.shape[1] > window_size:
            preceding_tokens = preceding_tokens[:, -window_size:]
        if preceding_tokens.shape[1] > 0:
            for i, item in enumerate(preceding_tokens):
                # check if the repeat ratio exceeds the threshold
                if (item == tokens[i]).sum() / window_size > threshold:
                    # replace the target code c_t' by random sampling
                    # make sure we don't sample the same token, by setting the probability of the token to 0
                    # logits[i][tokens[i]] = -float("Inf")
                    probs = F.softmax(logits[i], dim=-1)
                    token_new = torch.multinomial(probs, num_samples=1)
                    print(
                        f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}"
                    )
                    print("probs", probs, logits.shape)
                    tokens[i] = token_new
                else:
                    print(f"Not trigger: {i}, {item}, {tokens[i]}")
    return tokens
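# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original API).
#
# This block is a hedged usage example: the shapes, vocabulary sizes and seed
# below are made up for demonstration.  It only exercises components that are
# fully defined in this file (token/positional embeddings, the custom
# MultiheadAttention, and topk_sampling); building a full VALLE model
# additionally requires NUM_TEXT_TOKENS / NUM_AUDIO_TOKENS from `.macros` and
# the external norm/activation classes noted near the imports.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Token + positional embeddings: (N, T) int64 ids -> (N, T, D) features.
    emb = TokenEmbedding(dim_model=16, vocab_size=32)
    pos = SinePositionalEmbedding(dim_model=16, alpha=True)
    tokens = torch.randint(0, 32, (2, 7))
    feats = pos(emb(tokens))
    print(feats.shape)  # torch.Size([2, 7, 16])

    # The custom MultiheadAttention mirrors torch.nn.MultiheadAttention; in
    # training mode it falls back to F.multi_head_attention_forward.
    mha = MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
    attn_out, _ = mha(feats, feats, feats, need_weights=False)
    print(attn_out.shape)  # torch.Size([2, 7, 16])

    # Temperature + top-k/top-p sampling over a toy categorical distribution;
    # returns one sampled token id per batch item, shape (N, 1).
    logits = torch.randn(2, 32)
    samples = topk_sampling(logits, top_k=5, top_p=0.9, temperature=1.0)
    print(samples.shape)  # torch.Size([2, 1])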