# Copyright 2023 (authors: Feiteng Li) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import math import numbers import random from functools import partial from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn from tokenizer import TextTokenCollater from torch import Tensor from torch.nn import Linear, Module from torch.nn import functional as F from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear from torch.nn.parameter import Parameter from torchmetrics.classification import MulticlassAccuracy from icefall.utils import make_pad_mask NUM_TEXT_TOKENS = 5000 NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins class PromptedFeatures: def __init__(self, prompts, features): self.prompts = prompts self.features = features def to(self, device): return PromptedFeatures(self.prompts.to(device), self.features.to(device)) def sum(self): return self.features.sum() @property def ndim(self): return self.features.ndim @property def data(self): return (self.prompts, self.features) class TokenEmbedding(nn.Module): def __init__( self, dim_model: int, vocab_size: int, dropout: float = 0.0, ): super().__init__() self.vocab_size = vocab_size self.dim_model = dim_model self.dropout = torch.nn.Dropout(p=dropout) self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) @property def weight(self) 
-> torch.Tensor: return self.word_embeddings.weight def embedding(self, index: int) -> torch.Tensor: return self.word_embeddings.weight[index : index + 1] def forward(self, x: torch.Tensor): X = self.word_embeddings(x) X = self.dropout(X) return X class SinePositionalEmbedding(nn.Module): def __init__( self, dim_model: int, dropout: float = 0.0, scale: bool = False, alpha: bool = False, ): super().__init__() self.dim_model = dim_model self.x_scale = math.sqrt(dim_model) if scale else 1.0 self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) self.dropout = torch.nn.Dropout(p=dropout) self.reverse = False self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, 4000)) def extend_pe(self, x): """Reset the positional encodings.""" if self.pe is not None: if self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = torch.zeros(x.size(1), self.dim_model) if self.reverse: position = torch.arange( x.size(1) - 1, -1, -1.0, dtype=torch.float32 ).unsqueeze(1) else: position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.dim_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.dim_model) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.pe = pe.to(device=x.device, dtype=x.dtype).detach() def forward(self, x: torch.Tensor) -> torch.Tensor: self.extend_pe(x) output = x.unsqueeze(-1) if x.ndim == 2 else x output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] return self.dropout(output) class Transpose(nn.Identity): """(N, T, D) -> (N, D, T)""" def forward(self, input: torch.Tensor) -> torch.Tensor: return input.transpose(1, 2) _shape_t = Union[int, List[int], torch.Size] class MultiheadAttention(Module): r"""Allows the model to jointly attend to information from different representation subspaces as described in the paper: 
`Attention Is All You Need `_. Multi-Head Attention is defined as: .. math:: \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. ``forward()`` will use a special optimized implementation if all of the following conditions are met: - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This restriction will be loosened in the future.) - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` - training is disabled (using ``.eval()``) - dropout is 0 - ``add_bias_kv`` is ``False`` - ``add_zero_attn`` is ``False`` - ``batch_first`` is ``True`` and the input is batched - ``kdim`` and ``vdim`` are equal to ``embed_dim`` - at most one of ``key_padding_mask`` or ``attn_mask`` is passed - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` nor ``attn_mask`` is passed If the optimized implementation is in use, a `NestedTensor `_ can be passed for ``query``/``key``/``value`` to represent padding more efficiently than using a padding mask. In this case, a `NestedTensor `_ will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected. Args: embed_dim: Total dimension of the model. num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). bias: If specified, adds bias to input / output projection layers. Default: ``True``. add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: ``False``. kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). 
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). batch_first: If ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False`` (seq, batch, feature). Examples:: >>> # xdoctest: +SKIP >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value) """ __constants__ = ["batch_first"] bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] def __init__( self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, linear1_cls=Linear, linear2_cls=Linear, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout self.batch_first = batch_first self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" if add_bias_kv: self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None if linear1_cls == Linear: if not self._qkv_same_embed_dim: self.q_proj_weight = Parameter( torch.empty((embed_dim, embed_dim), **factory_kwargs) ) self.k_proj_weight = Parameter( torch.empty((embed_dim, self.kdim), **factory_kwargs) ) self.v_proj_weight = Parameter( torch.empty((embed_dim, self.vdim), **factory_kwargs) ) self.register_parameter("in_proj_weight", None) else: self.in_proj_weight = Parameter( torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) ) self.register_parameter("q_proj_weight", None) 
self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) if bias: self.in_proj_bias = Parameter( torch.empty(3 * embed_dim, **factory_kwargs) ) else: self.register_parameter("in_proj_bias", None) self.out_proj = NonDynamicallyQuantizableLinear( embed_dim, embed_dim, bias=bias, **factory_kwargs ) self._reset_parameters() else: if not self._qkv_same_embed_dim: raise NotImplementedError else: self.in_proj_linear = linear1_cls( embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs ) self.in_proj_weight = self.in_proj_linear.weight self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) if bias: self.in_proj_bias = self.in_proj_linear.bias else: self.register_parameter("in_proj_bias", None) self.out_proj = linear2_cls( embed_dim, embed_dim, bias=bias, **factory_kwargs ) if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: xavier_normal_(self.bias_v) self.add_zero_attn = add_zero_attn def _reset_parameters(self): if self._qkv_same_embed_dim: xavier_uniform_(self.in_proj_weight) else: xavier_uniform_(self.q_proj_weight) xavier_uniform_(self.k_proj_weight) xavier_uniform_(self.v_proj_weight) if self.in_proj_bias is not None: constant_(self.in_proj_bias, 0.0) constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: xavier_normal_(self.bias_v) def __setstate__(self, state): # Support loading old MultiheadAttention checkpoints generated by v1.1.0 if "_qkv_same_embed_dim" not in state: state["_qkv_same_embed_dim"] = True super(MultiheadAttention, self).__setstate__(state) def forward( self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query: Query embeddings of shape :math:`(L, 
E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against key-value pairs to produce the output. See "Attention Is All You Need" for more details. key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details. value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details. key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. Binary and byte masks are supported. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. Default: ``True``. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. 
A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) Outputs: - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the embedding dimension ``embed_dim``. - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. .. note:: `batch_first` argument is ignored for unbatched inputs. 
""" is_batched = query.dim() == 3 if key_padding_mask is not None: _kpm_dtype = key_padding_mask.dtype if _kpm_dtype != torch.bool and not torch.is_floating_point( key_padding_mask ): raise AssertionError( "only bool and floating types of key_padding_mask are supported" ) why_not_fast_path = "" if not is_batched: why_not_fast_path = ( f"input not batched; expected query.dim() of 3 but got {query.dim()}" ) elif query is not key or key is not value: # When lifting this restriction, don't forget to either # enforce that the dtypes all match or test cases where # they don't! why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" elif ( self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype ): # this case will fail anyway, but at least they'll get a useful error message. 
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" elif self.training: why_not_fast_path = "training is enabled" elif not self.batch_first: why_not_fast_path = "batch_first was not True" elif self.bias_k is not None: why_not_fast_path = "self.bias_k was not None" elif self.bias_v is not None: why_not_fast_path = "self.bias_v was not None" elif self.dropout: why_not_fast_path = f"dropout was {self.dropout}, required zero" elif self.add_zero_attn: why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: why_not_fast_path = "_qkv_same_embed_dim was not True" elif attn_mask is not None: why_not_fast_path = "attn_mask was not None" elif query.is_nested and key_padding_mask is not None: why_not_fast_path = ( "key_padding_mask is not supported with NestedTensor input" ) elif self.num_heads % 2 == 1: why_not_fast_path = "num_heads is odd" elif torch.is_autocast_enabled(): why_not_fast_path = "autocast is enabled" if not why_not_fast_path: tensor_args = ( query, key, value, self.in_proj_weight, self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, ) # We have to use list comprehensions below because TorchScript does not support # generator expressions. 
if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" elif not all( [ (x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args ] ): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any( [x is not None and x.requires_grad for x in tensor_args] ): why_not_fast_path = ( "grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad" ) if not why_not_fast_path: return torch._native_multi_head_attention( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, key_padding_mask if key_padding_mask is not None else attn_mask, need_weights, average_attn_weights, 1 if key_padding_mask is not None else 0 if attn_mask is not None else None, ) any_nested = query.is_nested or key.is_nested or value.is_nested assert not any_nested, ( "MultiheadAttention does not support NestedTensor outside of its fast path. 
" + f"The fast path was not hit because {why_not_fast_path}" ) if self.batch_first and is_batched: # make sure that the transpose op does not affect the "is" property if key is value: if query is key: query = key = value = query.transpose(1, 0) else: query, key = [x.transpose(1, 0) for x in (query, key)] value = key else: query, key, value = [x.transpose(1, 0) for x in (query, key, value)] if not self._qkv_same_embed_dim: attn_output, attn_output_weights = F.multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights, ) else: attn_output, attn_output_weights = F.multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, average_attn_weights=average_attn_weights, ) if self.batch_first and is_batched: return attn_output.transpose(1, 0), attn_output_weights else: return attn_output, attn_output_weights class LayerNorm(nn.Module): __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: Tuple[int, ...] 
eps: float elementwise_affine: bool def __init__( self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(LayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): # mypy error: incompatible types in assignment normalized_shape = (normalized_shape,) # type: ignore[assignment] self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = nn.Parameter( torch.empty(self.normalized_shape, **factory_kwargs) ) self.bias = nn.Parameter( torch.empty(self.normalized_shape, **factory_kwargs) ) else: self.register_parameter("weight", None) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: if self.elementwise_affine: nn.init.ones_(self.weight) nn.init.zeros_(self.bias) def forward(self, input: Tensor, embedding: Any = None) -> Tensor: if isinstance(input, tuple): input, embedding = input return ( F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps, ), embedding, ) assert embedding is None return F.layer_norm( input, self.normalized_shape, self.weight, self.bias, self.eps ) def extra_repr(self) -> str: return ( "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) ) class AdaptiveLayerNorm(nn.Module): r"""Adaptive Layer Normalization""" def __init__(self, d_model, norm) -> None: super(AdaptiveLayerNorm, self).__init__() self.project_layer = nn.Linear(d_model, 2 * d_model) self.norm = norm self.d_model = d_model self.eps = self.norm.eps def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: if isinstance(input, tuple): input, embedding = input weight, bias = torch.split( self.project_layer(embedding), split_size_or_sections=self.d_model, dim=-1, ) return (weight * self.norm(input) + bias, 
embedding) weight, bias = torch.split( self.project_layer(embedding), split_size_or_sections=self.d_model, dim=-1, ) return weight * self.norm(input) + bias class TransformerEncoderLayer(nn.Module): __constants__ = ["batch_first", "norm_first"] def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None, linear1_self_attention_cls: nn.Module = nn.Linear, linear2_self_attention_cls: nn.Module = nn.Linear, linear1_feedforward_cls: nn.Module = nn.Linear, linear2_feedforward_cls: nn.Module = nn.Linear, layer_norm_cls: nn.Module = LayerNorm, layer_norm_eps: float = 1e-5, adaptive_layer_norm=False, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, linear1_cls=linear1_self_attention_cls, linear2_cls=linear2_self_attention_cls, **factory_kwargs, ) # Implementation of Feedforward model self.linear1 = linear1_feedforward_cls( d_model, dim_feedforward, **factory_kwargs ) self.dropout = nn.Dropout(dropout) self.linear2 = linear2_feedforward_cls( dim_feedforward, d_model, **factory_kwargs ) self.norm_first = norm_first self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): activation = _get_activation_fn(activation) elif isinstance(activation, partial): activation = activation(d_model) # elif activation == BalancedDoubleSwish: # activation = BalancedDoubleSwish(d_model) # # We can't test self.activation in forward() in TorchScript, # # so stash some information about it instead. 
# if activation is F.relu or isinstance(activation, torch.nn.ReLU): # self.activation_relu_or_gelu = 1 # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): # self.activation_relu_or_gelu = 2 # else: # self.activation_relu_or_gelu = 0 self.activation = activation norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) # if layer_norm_cls == IdentityNorm: # norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) # else: if True: norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if adaptive_layer_norm: self.norm1 = AdaptiveLayerNorm(d_model, norm1) self.norm2 = AdaptiveLayerNorm(d_model, norm2) else: self.norm1 = norm1 self.norm2 = norm2 def __setstate__(self, state): super(TransformerEncoderLayer, self).__setstate__(state) if not hasattr(self, "activation"): self.activation = F.relu def forward( self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). Shape: see the docs in Transformer class. 
""" x, stage_embedding = src, None is_src_tuple = False if isinstance(src, tuple): x, stage_embedding = src is_src_tuple = True if src_key_padding_mask is not None: _skpm_dtype = src_key_padding_mask.dtype if _skpm_dtype != torch.bool and not torch.is_floating_point( src_key_padding_mask ): raise AssertionError( "only bool and floating types of key_padding_mask are supported" ) if self.norm_first: x = x + self._sa_block( self.norm1(x, stage_embedding), src_mask, src_key_padding_mask, ) x = x + self._ff_block(self.norm2(x, stage_embedding)) else: x = self.norm1( x + self._sa_block(x, src_mask, src_key_padding_mask), stage_embedding, ) x = self.norm2(x + self._ff_block(x), stage_embedding) if is_src_tuple: return (x, stage_embedding) return x # self-attention block def _sa_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], ) -> Tensor: x = self.self_attn( x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0] return self.dropout1(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) class TransformerEncoder(nn.Module): r"""TransformerEncoder is a stack of N encoder layers. Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. Args: encoder_layer: an instance of the TransformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). norm: the layer normalization component (optional). enable_nested_tensor: if True, input will automatically convert to nested tensor (and convert back on output). This will improve the overall performance of TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 
Examples:: >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src) """ __constants__ = ["norm"] def __init__(self, encoder_layer, num_layers, norm=None): super(TransformerEncoder, self).__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward( self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, return_layer_states: bool = False, ) -> Tensor: r"""Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). return_layer_states: return layers' state (optional). Shape: see the docs in Transformer class. """ if return_layer_states: layer_states = [] # layers' output output = src for mod in self.layers: output = mod( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) layer_states.append(output[0]) if self.norm is not None: output = self.norm(output) return layer_states, output output = src for mod in self.layers: output = mod( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask ) if self.norm is not None: output = self.norm(output) return output def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class VALLE(nn.Module): """It implements https://arxiv.org/abs/2301.02111 "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" """ def __init__( self, d_model: int, nhead: int, num_layers: int, norm_first: bool = True, add_prenet: bool = False, 
decoder_cls=TransformerEncoder, decoder_layer_cls=TransformerEncoderLayer, prefix_mode: int = 0, share_embedding: bool = True, nar_scale_factor: float = 1.0, prepend_bos: bool = False, num_quantizers: int = 8, **kwargs, ): """ Args: d_model: The number of expected features in the input (required). nhead: The number of heads in the multiheadattention models (required). num_layers: The number of sub-decoder-layers in the decoder (required). """ super().__init__() nar_d_model = int(d_model * nar_scale_factor) self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS) # ID NUM_AUDIO_TOKENS -> PAD # ID NUM_AUDIO_TOKENS + 1 -> BOS self.ar_audio_prepend_bos = prepend_bos self.ar_audio_embedding = TokenEmbedding( d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos) ) # PreNet if add_prenet: self.ar_text_prenet = nn.Sequential( Transpose(), nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(d_model), nn.ReLU(), nn.Dropout(0.5), Transpose(), nn.Linear(d_model, d_model), ) self.ar_audio_prenet = nn.Sequential( nn.Linear(d_model, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, d_model), ) else: self.ar_text_prenet = nn.Identity() self.ar_audio_prenet = nn.Identity() self.ar_text_position = SinePositionalEmbedding( d_model, dropout=0.1, scale=False, alpha=True, ) self.ar_audio_position = SinePositionalEmbedding( d_model, dropout=0.1, scale=False, alpha=True, ) self.ar_decoder = decoder_cls( decoder_layer_cls( d_model, nhead, dim_feedforward=d_model * 4, dropout=0.1, batch_first=True, norm_first=norm_first, ), num_layers=num_layers, norm=LayerNorm(d_model) if norm_first else None, ) self.ar_predict_layer 
= nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False) self.ar_accuracy_metric = MulticlassAccuracy( NUM_AUDIO_TOKENS + 1, top_k=10, average="micro", multidim_average="global", ignore_index=NUM_AUDIO_TOKENS, ) self.rng = random.Random(0) self.num_heads = nhead self.prefix_mode = prefix_mode self.num_quantizers = num_quantizers assert num_quantizers >= 1 if num_quantizers > 1: self.nar_audio_embeddings = nn.ModuleList( [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)] + [ TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS) for i in range(num_quantizers - 1) ] ) # W_a # PreNet if add_prenet: self.nar_text_prenet = nn.Sequential( Transpose(), nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(nar_d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(nar_d_model), nn.ReLU(), nn.Dropout(0.5), nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), nn.BatchNorm1d(nar_d_model), nn.ReLU(), nn.Dropout(0.5), Transpose(), nn.Linear(nar_d_model, nar_d_model), ) self.nar_audio_prenet = nn.Sequential( nn.Linear(nar_d_model, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.25), nn.Linear(256, nar_d_model), ) else: self.nar_text_prenet = nn.Identity() self.nar_audio_prenet = nn.Identity() self.nar_text_position = SinePositionalEmbedding( nar_d_model, dropout=0.0, scale=False, alpha=False, ) self.nar_audio_position = SinePositionalEmbedding( nar_d_model, dropout=0.1, scale=False, alpha=False, ) self.nar_decoder = decoder_cls( decoder_layer_cls( nar_d_model, int(nhead * nar_scale_factor), dim_feedforward=nar_d_model * 4, dropout=0.1, batch_first=True, norm_first=norm_first, adaptive_layer_norm=True, ), num_layers=int(num_layers * nar_scale_factor), norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model)) if norm_first else None, ) self.nar_predict_layers = nn.ModuleList( [ nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False) for i in 
range(num_quantizers - 1) ] ) self.nar_stage_embeddings = nn.ModuleList( [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)] ) if share_embedding: # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa # NOTE(Feiteng): In the experiment, this undermines accuracy # self.ar_predict_layer.weight = self.ar_audio_embedding.weight # We also share the parameters of the acoustic embedding layer and the output prediction layer, # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer. for j in range(0, num_quantizers - 2): self.nar_predict_layers[j].weight = self.nar_audio_embeddings[ j + 2 ].weight self.nar_accuracy_metric = MulticlassAccuracy( NUM_AUDIO_TOKENS + 1, top_k=10, average="micro", multidim_average="global", ignore_index=NUM_AUDIO_TOKENS, ) def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: """Yield the parameters trained in `stage`. Stage 1 yields parameters whose names start with "ar_", stage 2 those whose names start with "nar_"; any other stage > 0 yields nothing. The name of every yielded parameter is printed. """ assert stage > 0 if stage == 1: for name, param in self.named_parameters(): if name.startswith("ar_"): print(f" AR parameter: {name}") yield param if stage == 2: for name, param in self.named_parameters(): if name.startswith("nar_"): print(f"NAR parameter: {name}") yield param def stage_named_parameters( self, stage: int = 1 ) -> Iterator[Tuple[str, nn.Parameter]]: """Like stage_parameters, but yield (name, parameter) pairs and print nothing. """ assert stage > 0 if stage == 1: for pair in self.named_parameters(): if pair[0].startswith("ar_"): yield pair if stage == 2: for pair in self.named_parameters(): if pair[0].startswith("nar_"): yield pair def pad_y_eos(self, y, y_mask_int, eos_id): """Build AR decoder (inputs, targets) from first-codebook tokens `y` of shape (N, T). One column is appended and `eos_id` is added at every padded position (where `y_mask_int` is 1) and in the appended column, so padded entries of `y` must already be 0 (forward zeroes them before calling). With ar_audio_prepend_bos, inputs are a BOS token (NUM_AUDIO_TOKENS + 1) followed by targets[:, :-1], and targets keep length T + 1; otherwise inputs are targets[:, :-1] and targets are targets[:, 1:]. """ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( y_mask_int, (0, 1), value=1 ) # inputs, targets if self.ar_audio_prepend_bos: return ( F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1), targets, ) return targets[:, :-1], targets[:, 1:] def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): """Build the NAR decoder audio input embeddings for `nar_stage` and return (y_emb, prefix_len). prefix_mode 0: no prompt (prefix_len 0); sums the embeddings of `y` and of codebooks 1 .. nar_stage - 1. prefix_mode 1: a prefix of random length in [low, 2 * low) with low = int(0.25 * min(y_lens)), capped at 225 frames, is the prompt and is embedded with all codebooks; NOTE(review): torch.randint likely raises when low is 0 (shortest y_lens below 4). prefix_mode 2: per example, a random segment of min(225, int(0.25 * min(y_lens))) frames is copied as the prompt, and that segment of codes[..., nar_stage] is overwritten in place with NUM_AUDIO_TOKENS. prefix_mode 4: the caller-supplied y_prompts_codes are the prompt. In modes 1, 2 and 4 the prompt embeddings are prepended to y_emb. Raises ValueError for any other prefix_mode. """ # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds # from the same 
utterance. # We implement this differently. if self.prefix_mode == 0: # no prefix prefix_len = 0 y_emb = self.nar_audio_embeddings[0](y) for j in range(1, nar_stage): # Formula (4) (5) y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) elif self.prefix_mode == 1: # prefix at beginning int_low = (0.25 * y_lens.min()).type(torch.int64).item() prefix_len = torch.randint(int_low, int_low * 2, size=()).item() prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) for j in range(1, self.num_quantizers): y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) if j < nar_stage: y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) y_emb = torch.concat([y_prompts, y_emb], axis=1) elif self.prefix_mode in [2, 4]: if self.prefix_mode == 2: # random prefix prefix_len = min(225, int(0.25 * y_lens.min().item())) y_prompts_codes = [] for b in range(codes.shape[0]): start = self.rng.randint(0, y_lens[b].item() - prefix_len) y_prompts_codes.append( torch.clone(codes[b, start : start + prefix_len]) ) codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS y_prompts_codes = torch.stack(y_prompts_codes, dim=0) else: prefix_len = y_prompts_codes.shape[1] y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) y_emb = self.nar_audio_embeddings[0](y) for j in range(1, self.num_quantizers): y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) if j < nar_stage: y_emb += self.nar_audio_embeddings[j](codes[..., j]) y_emb = torch.concat([y_prompts, y_emb], axis=1) else: raise ValueError return y_emb, prefix_len def forward( self, x: torch.Tensor, x_lens: torch.Tensor, y: Union[torch.Tensor, PromptedFeatures], y_lens: Union[torch.Tensor, PromptedFeatures], reduction: str = "sum", train_stage: int = 0, **kwargs, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Dict[str, torch.Tensor]]: """Compute the AR and/or NAR training loss for a batch. Args: x: A 2-D tensor of shape (N, S). 
x_lens: A 1-D tensor of shape (N,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of audio codes of shape (N, T, 8) (codebooks 0 .. num_quantizers - 1 are read), or a PromptedFeatures whose `.data` is (prompt codes, codes); the latter requires prefix_mode == 4. y_lens: A 1-D tensor of shape (N,). It contains the number of frames in `y` before padding. If `y` is a PromptedFeatures, `y_lens` must be one too, with `.data` = (prompt lengths, lengths) and all prompt lengths equal. reduction: Passed to F.cross_entropy for the AR and NAR losses. train_stage: 0: AR & NAR modules, 1: AR modules, 2: NAR modules. With 0 (and num_quantizers > 1) the summed AR + NAR loss is halved. The NAR part trains one quantizer stage per call, drawn uniformly from 1 .. num_quantizers - 1 with self.rng. **kwargs: Ignored. Returns: A tuple ((x, codes), total_loss, metrics). `x` is the text input after the last text embedding stage that ran; `codes` is the int64 code matrix with padded frames zeroed (for prefix_mode 2, prompt segments of the sampled NAR stage are also overwritten in place with NUM_AUDIO_TOKENS); `total_loss` is the cross-entropy loss; `metrics` holds "ArTop10Accuracy" and/or "NarTop10Accuracy", each the top-10 accuracy multiplied by the total number of frames y_lens.sum(). """ assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape y_prompts_codes = None if isinstance(y, PromptedFeatures): y_prompts_codes, y = y.data prompts_len, y_lens = y_lens.data assert prompts_len.min() == prompts_len.max() assert self.prefix_mode == 4 y_prompts_codes = y_prompts_codes.type(torch.int64) assert y.ndim == 3, y.shape assert y_lens.ndim == 1, y_lens.shape # NOTE: x has been padded in TextTokenCollater x_mask = make_pad_mask(x_lens).to(x.device) y_mask = make_pad_mask(y_lens).to(y.device) y_mask_int = y_mask.type(torch.int64) text = x codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS) x_len = x_lens.max() metrics = {} total_loss = 0.0 xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) if self.ar_audio_prepend_bos: ar_xy_padding_mask = torch.concat( [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 ) else: ar_xy_padding_mask = xy_padding_mask # AR Decoder if train_stage in [0, 1]: x = self.ar_text_embedding(text) x = self.ar_text_prenet(x) x = self.ar_text_position(x) y_len = y_lens.max() + int(self.ar_audio_prepend_bos) x_attn_mask = F.pad( torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), (0, y_len), value=True, ) y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1, ), (x_len, 0), value=False, ) xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) # merge key padding and attention masks bsz, src_len = x.shape[0], x_len + y_len _xy_padding_mask = ( ar_xy_padding_mask.view(bsz, 1, 1, src_len) 
.expand(-1, self.num_heads, -1, -1) .reshape(bsz * self.num_heads, 1, src_len) ) xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) xy_attn_mask = new_attn_mask y_emb = self.ar_audio_embedding(y) y_emb = self.ar_audio_prenet(y_emb) y_pos = self.ar_audio_position(y_emb) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.ar_decoder( (xy_pos, None), mask=xy_attn_mask, # src_key_padding_mask=xy_padding_mask, # is_causal=True, ) logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) # loss total_loss = F.cross_entropy(logits, targets, reduction=reduction) metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( logits.detach(), targets ).item() * y_lens.sum().type(torch.float32) if self.num_quantizers == 1: return ((x, codes), total_loss, metrics) # Non-AR Decoders if self.ar_audio_prepend_bos: y = y[:, 1:] if train_stage in [0, 2]: num_nar_layers = self.num_quantizers - 1 nar_stage = self.rng.choices( [_k for _k in range(1, self.num_quantizers)], weights=[1.0 / num_nar_layers] * num_nar_layers, k=1, )[0] x = self.nar_text_embedding(text) x = self.nar_text_prenet(x) x = self.nar_text_position(x) y_emb, prefix_len = self._prepare_prompts( y, y_lens, codes, nar_stage, y_prompts_codes ) y_len = y_lens.max() targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int if self.prefix_mode in [2, 4]: xy_padding_mask = torch.concat( [ x_mask, F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), ], dim=1, ) elif self.prefix_mode == 1: targets = targets[:, prefix_len:] y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), src_key_padding_mask=xy_padding_mask, # is_causal=False, ) xy_dec = xy_dec[:, x_lens.max() + prefix_len :] if self.prefix_mode == 4: prefix_len = 0 # reset for 
Top10Accuracy metric logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1) # loss total_length = (y_lens).sum().type(torch.float32) total_loss += F.cross_entropy( logits, targets, ignore_index=NUM_AUDIO_TOKENS, reduction=reduction, ) * (total_length / (total_length - prefix_len * x.shape[0])) metrics["NarTop10Accuracy"] = ( self.nar_accuracy_metric( F.pad( logits.detach(), (0, 0, 0, 1, 0, 0), value=logits.min().cpu().item(), ), targets, ).item() * total_length ) if train_stage == 0: total_loss = total_loss / 2.0 return ((x, codes), total_loss, metrics) def inference( self, x: torch.Tensor, x_lens: torch.Tensor, y: torch.Tensor, enroll_x_lens: torch.Tensor, top_k: int = -100, temperature: float = 1.0, top_p: float = 1.0, ras: bool = False, ) -> torch.Tensor: """Generate audio codes for the text `x`, conditioned on the acoustic prompt `y`. The first codebook is sampled autoregressively with topk_sampling; the remaining codebooks are predicted greedily (argmax) by the NAR decoder. Args: x: A 2-D tensor of shape (1, S). x_lens: A 1-D tensor of shape (1,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (1, T, 8) holding the prompt codes. enroll_x_lens: A 1-D tensor whose max is the length of the enrolled (prompt) text at the start of `x`. Only used when prefix_mode is 2 or 4, to remove text tokens 1 .. enrolled_len - 2 from the NAR text input. top_k: (`optional`) int The number of highest probability tokens to keep for top-k-filtering; values <= 0 disable it. Default to -100. temperature: (`optional`) float The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0. top_p: (`optional`) float Nucleus (top-p) filtering threshold; 1.0 disables it. Default to 1.0. ras: (`optional`) bool Whether to use repetition-aware sampling. Default to False. Returns: Return the predicted audio code matrix of shape (1, T_gen, num_quantizers), without the prompt frames. AR decoding stops when the argmax or the sampled token is the EOS id NUM_AUDIO_TOKENS, or once more than 16 * x_lens.max() tokens beyond the prompt length have been decoded. Raises: SyntaxError if decoding stops while `y` is still as long as the prompt, i.e. nothing was generated (only reachable without ar_audio_prepend_bos). NOTE(review): SyntaxError is an unusual type for a runtime failure; kept for compatibility. 
""" assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3, y.shape assert y.shape[0] == 1, y.shape assert torch.all(x_lens > 0) # NOTE: x has been padded in TextTokenCollater text = x x = self.ar_text_embedding(text) x = self.ar_text_prenet(x) x = self.ar_text_position(x) text_len = x_lens.max() prompts = y prefix_len = y.shape[1] # AR Decoder # TODO: Managing decoder steps avoid repetitive computation y = prompts[..., 0] if self.ar_audio_prepend_bos: y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) x_len = x_lens.max() x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) while True: y_emb = self.ar_audio_embedding(y) y_emb = self.ar_audio_prenet(y_emb) y_pos = self.ar_audio_position(y_emb) xy_pos = torch.concat([x, y_pos], dim=1) y_len = y.shape[1] x_attn_mask_pad = F.pad( x_attn_mask, (0, y_len), value=True, ) y_attn_mask = F.pad( torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False, ) xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( y.device ) xy_dec, _ = self.ar_decoder( (xy_pos, None), mask=xy_attn_mask, ) logits = self.ar_predict_layer(xy_dec[:, -1]) samples = topk_sampling( logits, top_k=top_k, top_p=top_p, temperature=temperature, repetition_aware_sampling=ras, preceding_tokens=y, ) if ( torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS or samples[0, 0] == NUM_AUDIO_TOKENS or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 ): if prompts.shape[1] == y.shape[1]: raise SyntaxError("well trained model shouldn't reach here.") break y = torch.concat([y, samples], dim=1) codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] if self.num_quantizers == 1: return torch.stack(codes, dim=-1) # Non-AR Decoders y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :]) if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes enrolled_len = enroll_x_lens.max().item() # SOS + Synthesis Text + EOS text = torch.concat( [ text[:, :1], text[:, 
enrolled_len - 1 :], ], dim=1, ) text_len = text_len - (enrolled_len - 2) assert text.shape[0] == 1 x = self.nar_text_embedding(text) x = self.nar_text_prenet(x) x = self.nar_text_position(x) if self.prefix_mode == 0: for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < self.num_quantizers - 2: y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) y_emb[:, prefix_len:] += embedding_layer(samples) else: for j in range(1, self.num_quantizers): y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < self.num_quantizers - 2: y_emb[:, prefix_len:] += embedding_layer(samples) assert len(codes) == self.num_quantizers return torch.stack(codes, dim=-1) def continual( self, x: torch.Tensor, x_lens: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: """Continue an utterance from its first part; requires num_quantizers == 8. The first min(T // 2, 225) frames of `y` are the acoustic prompt. The first-codebook tokens of the remaining frames are copied from `y` (no AR decoding), and codebooks 1-7 are predicted greedily (argmax) by the NAR decoder. Args: x: A 2-D tensor of shape (1, S). x_lens: A 1-D tensor of shape (1,). It contains the number of tokens in `x` before padding. y: A 3-D tensor of shape (1, T, 8). Returns: Return the predicted audio code matrix of shape (1, T - prefix_len, 8) for the frames after the prompt. NOTE(review): the prefix_mode == 0 branch applies nar_audio_position before nar_audio_prenet, the reverse of the order used everywhere else in this class (forward, inference, and this method's other branch); likely a bug, confirm before relying on it. 
""" assert x.ndim == 2, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.ndim == 3, y.shape assert y.shape[0] == 1, y.shape assert torch.all(x_lens > 0) assert self.num_quantizers == 8 # NOTE: x has been padded in TextTokenCollater text = x x = self.ar_text_embedding(text) x = self.ar_text_prenet(x) x = self.ar_text_position(x) text_len = x_lens.max() prefix_len = min(int(y.shape[1] * 0.5), 3 * 75) # AR Decoder prompts = y[:, :prefix_len] codes = [y[:, prefix_len:, 0]] # Non-AR Decoders x = self.nar_text_embedding(text) x = self.nar_text_prenet(x) x = self.nar_text_position(x) y_emb = self.nar_audio_embeddings[0](y[..., 0]) if self.prefix_mode == 0: for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_position(y_emb) y_pos = self.nar_audio_prenet(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < 6: y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) y_emb[:, prefix_len:] += embedding_layer(samples) else: for j in range(1, 8): y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) for i, (predict_layer, embedding_layer) in enumerate( zip( self.nar_predict_layers, self.nar_audio_embeddings[1:], ) ): y_pos = self.nar_audio_prenet(y_emb) y_pos = self.nar_audio_position(y_pos) xy_pos = torch.concat([x, y_pos], dim=1) xy_dec, _ = self.nar_decoder( (xy_pos, self.nar_stage_embeddings[i].weight) ) logits = predict_layer(xy_dec[:, text_len + prefix_len :]) samples = torch.argmax(logits, dim=-1) codes.append(samples) if i < 6: y_emb[:, prefix_len:] += embedding_layer(samples) assert len(codes) == 8 return torch.stack(codes, dim=-1) def visualize( self, predicts: Tuple[torch.Tensor], batch: Dict[str, Union[List, torch.Tensor]], tokenizer: TextTokenCollater, 
output_dir: str, limit: int = 4, ) -> None: """Save one figure per utterance (at most `limit`) as `output_dir`/<utt_id>.png. Each figure has three heatmap panels: predicts[0] (labelled Encoder Output, with a red line at the text token length), predicts[1] (its last element if it is a list; labelled Decoder Output) and batch["features"] (labelled Decoder Target); the last two use a fixed color range [0, 1024] and a red line at the feature length. Reads batch keys "features", "features_lens", "tokens", "text" and "cut". """ audio_features = batch["features"].to("cpu").detach().numpy() audio_features_lens = batch["features_lens"].to("cpu").detach().numpy() tokens = batch["tokens"] text_tokens, text_tokens_lens = tokenizer(tokens) assert text_tokens.ndim == 2 texts = batch["text"] utt_ids = [cut.id for cut in batch["cut"]] encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() decoder_outputs = predicts[1] if isinstance(decoder_outputs, list): decoder_outputs = decoder_outputs[-1] decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() vmin, vmax = 0, 1024 # Encodec num_figures = 3 for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): _ = plt.figure(figsize=(14, 8 * num_figures)) S = text_tokens_lens[b] T = audio_features_lens[b] # encoder plt.subplot(num_figures, 1, 1) plt.title(f"Text: {text}") plt.imshow( X=np.transpose(encoder_outputs[b]), cmap=plt.get_cmap("jet"), aspect="auto", interpolation="nearest", ) plt.gca().invert_yaxis() plt.axvline(x=S - 0.4, linewidth=2, color="r") plt.xlabel("Encoder Output") plt.colorbar() # decoder plt.subplot(num_figures, 1, 2) plt.imshow( X=np.transpose(decoder_outputs[b]), cmap=plt.get_cmap("jet"), aspect="auto", interpolation="nearest", vmin=vmin, vmax=vmax, ) plt.gca().invert_yaxis() plt.axvline(x=T - 0.4, linewidth=2, color="r") plt.xlabel("Decoder Output") plt.colorbar() # target plt.subplot(num_figures, 1, 3) plt.imshow( X=np.transpose(audio_features[b]), cmap=plt.get_cmap("jet"), aspect="auto", interpolation="nearest", vmin=vmin, vmax=vmax, ) plt.gca().invert_yaxis() plt.axvline(x=T - 0.4, linewidth=2, color="r") plt.xlabel("Decoder Target") plt.colorbar() plt.savefig(f"{output_dir}/{utt_id}.png") plt.close() # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 ): """Filter a distribution of 
logits using top-k and/or nucleus (top-p) filtering. Args: logits: logits distribution of shape (batch size, vocabulary size); the top-p branch scatters along dim 1, so 2-D input is expected. top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering); it is first raised to at least min_tokens_to_keep and capped at the vocabulary size. top_p: if < 1.0, keep the highest-probability tokens up to and including the first one at which the cumulative probability exceeds top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) filter_value: value written into the removed positions. min_tokens_to_keep: keep at least this many tokens per batch example in the output. Returns: `logits` itself, modified in place. From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ if top_k > 0: top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > top_p if min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) logits[indices_to_remove] = filter_value return logits def topk_sampling( logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware_sampling=False, preceding_tokens=None, ): """Sample one token per row of `logits` (expected shape (batch, vocab)) and return the indices as a (batch, 1) tensor. The logits are divided by `temperature`, filtered on a copy with top_k_top_p_filtering (min_tokens_to_keep=2) and sampled with torch.multinomial. With repetition_aware_sampling, each sampled token is compared with the last 10 entries of its row of `preceding_tokens` (shape (batch, T)); if its count divided by 10 exceeds 0.1, it is resampled from the temperature-scaled but unfiltered distribution. NOTE(review): the comment block before the window truncation looks garbled in this copy (text after 'c' was lost); confirm the truncation condition. """ if temperature != 1.0: logits = logits / temperature # Top-p/top-k filtering logits_filtered = top_k_top_p_filtering( logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2 
) # Sample probs = F.softmax(logits_filtered, dim=-1) tokens = torch.multinomial(probs, num_samples=1) if repetition_aware_sampling: window_size = 10 threshold = 0.1 # we first generate the target code ct′ # by nucleus sampling with a pre-defined top-p value v. Then, we # calculate the repetition ratio r of token ct′ # in the preceding code sequence with a window size K. # If the ratio r exceeds a pre-defined repetition threshold ratio tn, we replace the target code ct′ # by # random sampling from p(ct′ # |x, c window_size: preceding_tokens = preceding_tokens[:, -window_size:] if preceding_tokens.shape[1] > 0: for i, item in enumerate(preceding_tokens): # check if the repeat ratio exceeds the threshold if (item == tokens[i]).sum() / window_size > threshold: # replace the target code ct′ by random sampling probs = F.softmax(logits[i], dim=-1) token_new = torch.multinomial(probs, num_samples=1) tokens[i] = token_new return tokens