1732 lines
65 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import numbers
import random
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tokenizer import TextTokenCollater
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torchmetrics.classification import MulticlassAccuracy
from icefall.utils import make_pad_mask
# Vocabulary size used for the text-token embedding tables in this file.
NUM_TEXT_TOKENS = 5000
# Number of regular audio code values per quantizer; special ids used in this
# file (e.g. padding / BOS) are placed at NUM_AUDIO_TOKENS and above.
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins
class PromptedFeatures:
    """Bundle a ``prompts`` object with a ``features`` object.

    The wrapper exposes a small tensor-like surface: ``to`` moves both
    members, while ``sum`` and ``ndim`` are answered from ``features`` only.
    """

    def __init__(self, prompts, features):
        self.prompts, self.features = prompts, features

    def to(self, device):
        """Return a new ``PromptedFeatures`` with both members moved to ``device``."""
        moved_prompts = self.prompts.to(device)
        moved_features = self.features.to(device)
        return PromptedFeatures(moved_prompts, moved_features)

    def sum(self):
        """Return ``features.sum()``; the prompts do not contribute."""
        total = self.features.sum()
        return total

    @property
    def ndim(self):
        """Number of dimensions of ``features``."""
        features = self.features
        return features.ndim

    @property
    def data(self):
        """The ``(prompts, features)`` pair as a tuple."""
        pair = (self.prompts, self.features)
        return pair
class TokenEmbedding(nn.Module):
    """Embedding table of ``vocab_size`` rows of width ``dim_model`` plus dropout."""

    def __init__(self, dim_model: int, vocab_size: int, dropout: float = 0.0):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        # Dropout is registered before the table, as in the original layout.
        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """The underlying embedding matrix."""
        table = self.word_embeddings
        return table.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return row ``index`` of the table as a one-row slice (no dropout)."""
        row = slice(index, index + 1)
        return self.word_embeddings.weight[row]

    def forward(self, x: torch.Tensor):
        """Look up the ids in ``x`` and apply dropout to the result."""
        return self.dropout(self.word_embeddings(x))
class SinePositionalEmbedding(nn.Module):
    """Add a sinusoidal positional table to a sequence.

    The table ``pe`` has shape ``(1, length, dim_model)``; even feature
    columns hold sines and odd columns hold cosines. It is kept as a plain
    attribute (not a registered buffer) and is rebuilt or cast lazily by
    :meth:`extend_pe`.

    Args:
        dim_model: Width of the positional table (last dimension).
        dropout: Dropout probability applied to the output.
        scale: If true, the input is multiplied by ``sqrt(dim_model)``.
        alpha: If true, the scalar multiplier of the positional table is
            trainable; otherwise it stays at its initial value of one.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        # Pre-build a table for 4000 positions so typical inputs never rebuild.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Ensure ``self.pe`` covers ``x.size(1)`` positions on ``x``'s device/dtype.

        If the cached table is already long enough it is only cast/moved when
        its dtype or device differs from ``x``; otherwise a new table of
        exactly ``x.size(1)`` positions is built.
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd dim_model there is one fewer cosine column than sine
        # columns; slice so the assignment shapes agree. For even sizes the
        # slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : self.dim_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x * x_scale + alpha * pe[:, :x.size(1)])``.

        A 2-D ``x`` gets a trailing singleton dimension before the addition.
        """
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
class Transpose(nn.Identity):
    """Swap dimensions 1 and 2 of the input: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        swapped = torch.transpose(input, 1, 2)
        return swapped
# Type alias for shape-like arguments (an int, a list of ints, or a torch.Size).
_shape_t = Union[int, List[int], torch.Size]
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        linear1_cls: Class selecting how the input projection is built. With the default
            ``Linear`` the projection weights are created directly as parameters; with any
            other class an instance ``linear1_cls(embed_dim, 3 * embed_dim)`` is created and
            its ``weight``/``bias`` are used as the packed input projection (this path
            requires ``kdim == vdim == embed_dim``).
        linear2_cls: Class used to build the output projection when ``linear1_cls`` is not
            ``Linear``; otherwise the output projection is a ``NonDynamicallyQuantizableLinear``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # A single packed projection weight is only possible when q/k/v share a width.
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            # Default path: raw projection parameters, initialised by _reset_parameters().
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            # Custom-linear path: the projection lives in a submodule and its
            # parameters are additionally exposed under the usual attribute names.
            # _reset_parameters() is not called here, so the custom classes keep
            # their own initialisation.
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        """Initialise projection weights (Xavier uniform), zero the biases, and
        Xavier-normal initialise ``bias_k``/``bias_v`` when present."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Only bool and floating-point masks are accepted here; any other dtype raises ``AssertionError``.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        # Collect the first reason (if any) that rules out the native fast path;
        # an empty string means the fast path is still eligible.
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                # Fast path: hand everything to the native kernel. The last
                # argument is the mask-type code (1 when a key padding mask is
                # passed, 0 for an attention mask, None when there is no mask).
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        # The functional implementation below expects (seq, batch, feature),
        # so batch-first inputs are transposed on the way in and out.
        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            # Separate q/k/v projection weights.
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
        else:
            # Single packed in_proj_weight.
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
class LayerNorm(nn.Module):
    """Layer normalization that can pass an accompanying embedding through.

    ``forward`` accepts either a tensor or an ``(input, embedding)`` tuple.
    In the tuple case the first element is normalized and the second is
    returned untouched next to it.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        # A bare integer means "normalize over one trailing dimension of that size".
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set the affine weight to ones and the bias to zeros (if present)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``; a tuple input yields a ``(normalized, embedding)`` tuple."""
        came_as_tuple = isinstance(input, tuple)
        if came_as_tuple:
            input, embedding = input
        else:
            # A separate embedding is only meaningful together with a tuple input.
            assert embedding is None

        normalized = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        if came_as_tuple:
            return (normalized, embedding)
        return normalized

    def extra_repr(self) -> str:
        template = (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}"
        )
        return template.format(
            normalized_shape=self.normalized_shape,
            eps=self.eps,
            elementwise_affine=self.elementwise_affine,
        )
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    The wrapped ``norm`` is applied to the input and the result is scaled and
    shifted by two ``d_model``-wide vectors obtained from a linear projection
    of ``embedding``. Like ``forward`` of the wrapped norm's callers expect,
    the input may also arrive as an ``(input, embedding)`` tuple, in which
    case a ``(output, embedding)`` tuple is returned.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        packed = isinstance(input, tuple)
        if packed:
            input, embedding = input

        # One projection yields both halves: first the scale, then the shift.
        scale, shift = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        modulated = scale * self.norm(input) + shift

        if packed:
            return (modulated, embedding)
        return modulated
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: a self-attention block followed by a feed-forward block.

    Each block is wrapped in a residual connection with a normalization layer
    placed before the block (``norm_first=True``) or after the residual sum
    (``norm_first=False``). The input may be a tensor or an
    ``(x, stage_embedding)`` tuple; the embedding is forwarded to both
    normalization layers and a tuple input produces a tuple output.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        # Self-attention sub-layer.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Feed-forward sub-layer: linear1 -> activation -> dropout -> linear2.
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Resolve the activation: a name is looked up, a ``partial`` is treated
        # as a factory taking ``d_model``, anything else is used as given.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        self.activation = activation

        # Both base norms are built first; the adaptive wrappers (if requested)
        # are created afterwards, in the same order.
        base_norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        base_norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if not adaptive_layer_norm:
            self.norm1 = base_norm1
            self.norm2 = base_norm2
        else:
            self.norm1 = AdaptiveLayerNorm(d_model, base_norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, base_norm2)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restored states lacking an activation fall back to ReLU.
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required); either a tensor
                or an ``(x, stage_embedding)`` tuple.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional);
                must be of bool or floating dtype.

        Returns:
            The layer output, wrapped as ``(output, stage_embedding)`` when
            ``src`` was given as a tuple.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None

        if src_key_padding_mask is not None:
            mask_is_bool = src_key_padding_mask.dtype == torch.bool
            if not (mask_is_bool or torch.is_floating_point(src_key_padding_mask)):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if not self.norm_first:
            # Post-norm: normalize after each residual sum.
            attended = x + self._sa_block(x, src_mask, src_key_padding_mask)
            x = self.norm1(attended, stage_embedding)
            x = self.norm2(x + self._ff_block(x), stage_embedding)
        else:
            # Pre-norm: normalize the input of each block.
            attn_in = self.norm1(x, stage_embedding)
            x = x + self._sa_block(attn_in, src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x, stage_embedding))

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Self-attention over ``x`` (weights not requested) followed by dropout."""
        attn_result = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout1(attn_result[0])

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward block followed by dropout."""
        hidden = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(hidden))
        return self.dropout2(projected)
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
            It is deep-copied ``num_layers`` times.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required). Whatever the layers
                accept is passed through unchanged, including an
                ``(x, embedding)`` tuple.
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: return layers' state (optional).

        Returns:
            The output of the last layer (after ``norm`` if one is set). When
            ``return_layer_states`` is true, a pair ``(layer_states, output)``
            where ``layer_states`` lists each layer's output, taken before the
            final ``norm``; for tuple-valued layer outputs only the first
            element is recorded.
        """
        layer_states = [] if return_layer_states else None

        output = src
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            if layer_states is not None:
                # Only unwrap tuple outputs; indexing a plain tensor with [0]
                # would record just its first batch entry.
                state = output[0] if isinstance(output, tuple) else output
                layer_states.append(state)

        if self.norm is not None:
            output = self.norm(output)

        if layer_states is not None:
            return layer_states, output
        return output
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class VALLE(nn.Module):
"""It implements https://arxiv.org/abs/2301.02111
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
"""
def __init__(
    self,
    d_model: int,
    nhead: int,
    num_layers: int,
    norm_first: bool = True,
    add_prenet: bool = False,
    decoder_cls=TransformerEncoder,
    decoder_layer_cls=TransformerEncoderLayer,
    prefix_mode: int = 0,
    share_embedding: bool = True,
    nar_scale_factor: float = 1.0,
    prepend_bos: bool = False,
    num_quantizers: int = 8,
    **kwargs,
):
    """
    Build the autoregressive (``ar_*``) and non-autoregressive (``nar_*``)
    sub-networks.

    Args:
      d_model:
        The number of expected features in the input (required).
      nhead:
        The number of heads in the multiheadattention models (required).
      num_layers:
        The number of sub-decoder-layers in the decoder (required).
      norm_first:
        Forwarded to every decoder layer; when true a final normalization
        layer is also attached to each decoder stack.
      add_prenet:
        If true, convolutional / linear prenets are built for text and
        audio inputs; otherwise the prenets are identity modules.
      decoder_cls:
        Class used to build the decoder stacks.
      decoder_layer_cls:
        Class used to build a single decoder layer.
      prefix_mode:
        Stored on the instance as ``self.prefix_mode``.
      share_embedding:
        If true, the weight of NAR prediction layer ``j`` is tied to the
        weight of NAR audio embedding ``j + 2``.
      nar_scale_factor:
        Multiplier applied to ``d_model``, ``nhead`` and ``num_layers``
        (each truncated with ``int``) for the NAR sub-network.
      prepend_bos:
        If true, the AR audio embedding gets one extra entry (for BOS) and
        the flag is stored as ``self.ar_audio_prepend_bos``.
      num_quantizers:
        Number of quantizer levels; must be at least 1.
      **kwargs:
        Accepted but not used in this constructor.
    """
    # NOTE(review): this file's indentation was lost; the nesting of the NAR
    # section under ``if num_quantizers > 1:`` below is reconstructed - confirm
    # against the upstream source.
    super().__init__()
    nar_d_model = int(d_model * nar_scale_factor)

    self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
    self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

    # ID NUM_AUDIO_TOKENS     -> PAD
    # ID NUM_AUDIO_TOKENS + 1 -> BOS
    self.ar_audio_prepend_bos = prepend_bos
    self.ar_audio_embedding = TokenEmbedding(
        d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
    )

    # PreNet
    if add_prenet:
        # Text prenet: three Conv1d/BatchNorm/ReLU/Dropout stages applied in
        # (N, D, T) layout (hence the surrounding Transpose modules), then a
        # linear layer.
        self.ar_text_prenet = nn.Sequential(
            Transpose(),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(0.5),
            Transpose(),
            nn.Linear(d_model, d_model),
        )

        # Audio prenet: a three-layer MLP with a 256-wide bottleneck.
        self.ar_audio_prenet = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, d_model),
        )
    else:
        self.ar_text_prenet = nn.Identity()
        self.ar_audio_prenet = nn.Identity()

    # AR positional embeddings use a trainable alpha multiplier.
    self.ar_text_position = SinePositionalEmbedding(
        d_model,
        dropout=0.1,
        scale=False,
        alpha=True,
    )
    self.ar_audio_position = SinePositionalEmbedding(
        d_model,
        dropout=0.1,
        scale=False,
        alpha=True,
    )

    self.ar_decoder = decoder_cls(
        decoder_layer_cls(
            d_model,
            nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True,
            norm_first=norm_first,
        ),
        num_layers=num_layers,
        norm=LayerNorm(d_model) if norm_first else None,
    )
    # Output logits cover the NUM_AUDIO_TOKENS codes plus one extra class.
    self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)

    # Top-10 accuracy; the class with id NUM_AUDIO_TOKENS is ignored.
    self.ar_accuracy_metric = MulticlassAccuracy(
        NUM_AUDIO_TOKENS + 1,
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=NUM_AUDIO_TOKENS,
    )

    # Private RNG with a fixed seed.
    self.rng = random.Random(0)
    self.num_heads = nhead
    self.prefix_mode = prefix_mode
    self.num_quantizers = num_quantizers

    assert num_quantizers >= 1
    if num_quantizers > 1:
        # One embedding table per quantizer level; the first has one extra entry.
        self.nar_audio_embeddings = nn.ModuleList(
            [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
            + [
                TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                for i in range(num_quantizers - 1)
            ]
        )  # W_a

        # PreNet
        if add_prenet:
            self.nar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(nar_d_model, nar_d_model),
            )
            self.nar_audio_prenet = nn.Sequential(
                nn.Linear(nar_d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, nar_d_model),
            )
        else:
            self.nar_text_prenet = nn.Identity()
            self.nar_audio_prenet = nn.Identity()

        # NAR positional embeddings keep alpha fixed (not trainable).
        self.nar_text_position = SinePositionalEmbedding(
            nar_d_model,
            dropout=0.0,
            scale=False,
            alpha=False,
        )
        self.nar_audio_position = SinePositionalEmbedding(
            nar_d_model,
            dropout=0.1,
            scale=False,
            alpha=False,
        )

        # The NAR decoder uses adaptive layer norm in its layers and, when
        # norm_first is set, as the final norm.
        self.nar_decoder = decoder_cls(
            decoder_layer_cls(
                nar_d_model,
                int(nhead * nar_scale_factor),
                dim_feedforward=nar_d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=int(num_layers * nar_scale_factor),
            norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
            if norm_first
            else None,
        )
        # One prediction head per quantizer level after the first.
        self.nar_predict_layers = nn.ModuleList(
            [
                nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                for i in range(num_quantizers - 1)
            ]
        )
        # One single-entry embedding per NAR stage.
        self.nar_stage_embeddings = nn.ModuleList(
            [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
        )

        if share_embedding:
            # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
            # NOTE(Feiteng): In the experiment, this undermines accuracy
            # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

            # We also share the parameters of the acoustic embedding layer and the output prediction layer,
            # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
            # (Index j + 2 below is into the 0-based embedding list; the
            # last prediction layer keeps its own weight.)
            for j in range(0, num_quantizers - 2):
                self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                    j + 2
                ].weight

        self.nar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=NUM_AUDIO_TOKENS,
        )
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
    """Yield the parameters that belong to one training stage.

    Args:
      stage:
        1 selects parameters whose name starts with ``ar_``; 2 selects
        parameters whose name starts with ``nar_``. Any other positive
        value yields nothing.

    Yields:
      The matching parameters, in ``self.named_parameters()`` order. The
      name of every selected parameter is printed as a side effect.

    Raises:
      AssertionError: if ``stage`` is not positive (raised when iteration
        starts, because this is a generator).
    """
    assert stage > 0

    if stage == 1:
        wanted_prefix = "ar_"
    elif stage == 2:
        wanted_prefix = "nar_"
    else:
        # Unknown stages select nothing.
        return

    for param_name, tensor in self.named_parameters():
        if not param_name.startswith(wanted_prefix):
            continue
        if stage == 1:
            print(f" AR parameter: {param_name}")
        else:
            print(f"NAR parameter: {param_name}")
        yield tensor
def stage_named_parameters(
    self, stage: int = 1
) -> Iterator[Tuple[str, nn.Parameter]]:
    """Yield the ``(name, parameter)`` pairs that belong to one training stage.

    Args:
      stage:
        1 selects names starting with ``ar_``; 2 selects names starting
        with ``nar_``. Any other positive value yields nothing.

    Yields:
      The matching entries of ``self.named_parameters()``, unchanged and in
      their original order.

    Raises:
      AssertionError: if ``stage`` is not positive (raised when iteration
        starts, because this is a generator).
    """
    assert stage > 0

    wanted_prefix = {1: "ar_", 2: "nar_"}.get(stage)
    if wanted_prefix is None:
        return

    yield from (
        entry
        for entry in self.named_parameters()
        if entry[0].startswith(wanted_prefix)
    )
def pad_y_eos(self, y, y_mask_int, eos_id):
    """Append an EOS slot to ``y`` and derive shifted (inputs, targets).

    ``y`` and ``y_mask_int`` are both extended by one trailing column
    (0 for ``y``, 1 for the mask); ``eos_id`` is then added at every
    position where the extended mask is set.

    Args:
      y: integer tensor indexed as (batch, time).
      y_mask_int: integer mask with the same layout as ``y``.
      eos_id: value added at masked positions and at the appended column.

    Returns:
      A tuple ``(inputs, targets)``. When ``self.ar_audio_prepend_bos`` is
      set, inputs are the EOS-extended sequence without its last column and
      with ``NUM_AUDIO_TOKENS + 1`` prepended, and targets are the full
      EOS-extended sequence. Otherwise inputs drop the last column and
      targets drop the first one.
    """
    extended_codes = F.pad(y, (0, 1), value=0)
    extended_mask = F.pad(y_mask_int, (0, 1), value=1)
    targets = extended_codes + eos_id * extended_mask

    without_last = targets[:, :-1]
    if not self.ar_audio_prepend_bos:
        return without_last, targets[:, 1:]

    # Prepend the BOS id so inputs and targets keep the same length.
    with_bos = F.pad(without_last, (1, 0), value=NUM_AUDIO_TOKENS + 1)
    return (with_bos, targets)
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
# 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
# from the same utterance.
# We implement this differently.
if self.prefix_mode == 0:
# no prefix
prefix_len = 0
y_emb = self.nar_audio_embeddings[0](y)
for j in range(1, nar_stage):
# Formula (4) (5)
y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
elif self.prefix_mode == 1:
# prefix at begining
int_low = (0.25 * y_lens.min()).type(torch.int64).item()
prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
for j in range(1, self.num_quantizers):
y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
if j < nar_stage:
y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
y_emb = torch.concat([y_prompts, y_emb], axis=1)
elif self.prefix_mode in [2, 4]:
if self.prefix_mode == 2:
# random prefix
prefix_len = min(225, int(0.25 * y_lens.min().item()))
y_prompts_codes = []
for b in range(codes.shape[0]):
start = self.rng.randint(0, y_lens[b].item() - prefix_len)
y_prompts_codes.append(
torch.clone(codes[b, start : start + prefix_len])
)
codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
else:
prefix_len = y_prompts_codes.shape[1]
y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
y_emb = self.nar_audio_embeddings[0](y)
for j in range(1, self.num_quantizers):
y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
if j < nar_stage:
y_emb += self.nar_audio_embeddings[j](codes[..., j])
y_emb = torch.concat([y_prompts, y_emb], axis=1)
else:
raise ValueError
return y_emb, prefix_len
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: Union[torch.Tensor, PromptedFeatures],
    y_lens: Union[torch.Tensor, PromptedFeatures],
    reduction: str = "sum",
    train_stage: int = 0,
    **kwargs,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Compute the training loss of the AR and/or NAR decoders.

    Args:
      x:
        A 2-D tensor of shape (N, S).
      x_lens:
        A 1-D tensor of shape (N,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (N, T, 8), or a PromptedFeatures wrapping
        (prompt codes, codes). The latter requires `prefix_mode == 4`.
      y_lens:
        A 1-D tensor of shape (N,). It contains the number of tokens in `y`
        before padding. When `y` is a PromptedFeatures this is a
        PromptedFeatures of (prompt lengths, lengths) and all prompt
        lengths must be equal.
      reduction:
        Passed to `F.cross_entropy` for both the AR and the NAR loss.
      train_stage:
        0: AR & NAR modules, 1: AR modules, 2: NAR modules
    Returns:
      A 3-tuple `((x, codes), total_loss, metrics)`: `x` is the last text
      embedding computed (or the raw tokens if no decoder ran), `codes`
      are the int64 audio codes with masked frames zeroed, `total_loss` is
      the cross-entropy loss (halved when `train_stage == 0`), and
      `metrics` maps "ArTop10Accuracy" / "NarTop10Accuracy" to the
      accuracy multiplied by the total number of frames.
      NOTE(review): the return annotation lists two elements but three are
      returned.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape

    y_prompts_codes = None
    if isinstance(y, PromptedFeatures):
        # Unpack externally supplied prompts (only supported by prefix_mode 4).
        y_prompts_codes, y = y.data
        prompts_len, y_lens = y_lens.data
        assert prompts_len.min() == prompts_len.max()
        assert self.prefix_mode == 4
        y_prompts_codes = y_prompts_codes.type(torch.int64)

    assert y.ndim == 3, y.shape
    assert y_lens.ndim == 1, y_lens.shape

    # NOTE: x has been padded in TextTokenCollater
    x_mask = make_pad_mask(x_lens).to(x.device)
    y_mask = make_pad_mask(y_lens).to(y.device)
    y_mask_int = y_mask.type(torch.int64)

    text = x
    # Zero every frame flagged by y_mask (presumably the padded frames).
    codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

    # First-quantizer inputs/targets for the AR decoder, with EOS appended.
    y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS)

    x_len = x_lens.max()

    metrics = {}
    total_loss = 0.0

    xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
    if self.ar_audio_prepend_bos:
        # The prepended BOS position is never masked.
        ar_xy_padding_mask = torch.concat(
            [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
        )
    else:
        ar_xy_padding_mask = xy_padding_mask
    # AR Decoder
    if train_stage in [0, 1]:
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        y_len = y_lens.max() + int(self.ar_audio_prepend_bos)

        # Text rows: may attend to every text column, never to audio columns
        # (True marks a blocked position; it becomes -inf below).
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        # Audio rows: may attend to all text columns and causally to audio.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # merge key padding and attention masks
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_heads, -1, -1)
            .reshape(bsz * self.num_heads, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        # Convert the boolean mask into an additive float mask.
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)

        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
            # src_key_padding_mask=xy_padding_mask,
            # is_causal=True,
        )
        # Only the audio part of the decoder output is scored; permute puts
        # the class dimension second, as F.cross_entropy expects.
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        # loss
        total_loss = F.cross_entropy(logits, targets, reduction=reduction)

        # Accuracy is scaled by the total frame count.
        metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
            logits.detach(), targets
        ).item() * y_lens.sum().type(torch.float32)

    if self.num_quantizers == 1:
        # No NAR stages exist when there is a single quantizer.
        return ((x, codes), total_loss, metrics)

    # Non-AR Decoders
    if self.ar_audio_prepend_bos:
        # Drop the BOS column that pad_y_eos prepended.
        y = y[:, 1:]
    if train_stage in [0, 2]:
        # Pick one NAR stage uniformly from 1 .. num_quantizers - 1.
        num_nar_layers = self.num_quantizers - 1
        nar_stage = self.rng.choices(
            [_k for _k in range(1, self.num_quantizers)],
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        # NOTE: in prefix_mode 2 this call also modifies `codes` in place.
        y_emb, prefix_len = self._prepare_prompts(
            y, y_lens, codes, nar_stage, y_prompts_codes
        )

        y_len = y_lens.max()
        # Masked frames (zeroed in `codes`) become NUM_AUDIO_TOKENS, which is
        # the ignore_index of the loss below.
        targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
        if self.prefix_mode in [2, 4]:
            # The prompt frames were prepended to y_emb; extend the padding
            # mask on the left with unmasked entries to match.
            xy_padding_mask = torch.concat(
                [
                    x_mask,
                    F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
                ],
                dim=1,
            )
        elif self.prefix_mode == 1:
            # The prefix is part of the utterance; it is not predicted.
            targets = targets[:, prefix_len:]

        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=xy_padding_mask,
            # is_causal=False,
        )
        # Keep only the outputs after the text and the prompt.
        xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
        if self.prefix_mode == 4:
            prefix_len = 0  # reset for Top10Accuracy metric
        logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)

        # loss
        total_length = (y_lens).sum().type(torch.float32)
        # The loss is rescaled by total frames / (total frames minus the
        # prefix_len frames per utterance).
        total_loss += F.cross_entropy(
            logits,
            targets,
            ignore_index=NUM_AUDIO_TOKENS,
            reduction=reduction,
        ) * (total_length / (total_length - prefix_len * x.shape[0]))
        # Append one extra class row filled with the minimum logit before
        # calling the metric, so the class dimension grows by one.
        metrics["NarTop10Accuracy"] = (
            self.nar_accuracy_metric(
                F.pad(
                    logits.detach(),
                    (0, 0, 0, 1, 0, 0),
                    value=logits.min().cpu().item(),
                ),
                targets,
            ).item()
            * total_length
        )

    if train_stage == 0:
        # Both decoders contributed; average their losses.
        total_loss = total_loss / 2.0

    return ((x, codes), total_loss, metrics)
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    enroll_x_lens: torch.Tensor,
    top_k: int = -100,
    temperature: float = 1.0,
    top_p: float = 1.0,
    ras: bool = False,
) -> torch.Tensor:
    """Generate audio codes for one utterance given text and an audio prompt.

    The first quantizer is sampled autoregressively; the remaining
    quantizers are predicted stage by stage with greedy argmax.

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
      enroll_x_lens:
        Only read when `prefix_mode` is 2 or 4: its maximum is used as the
        length of the enrolled text at the start of `x`, which is removed
        before the NAR stages.
      top_k: (`optional`) int
        The number of highest probability tokens to keep for top-k-filtering. Default to -100.
      temperature: (`optional`) float
        The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
      top_p: (`optional`) float
        Nucleus threshold forwarded to `topk_sampling`. Default to 1.0.
      ras: (`optional`) bool
        Whether to use repetition-aware sampling. Default to False.
    Returns:
      Return the predicted audio code matrix.
    Raises:
      SyntaxError: if the AR decoder stops before generating any token.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)

    # NOTE: x has been padded in TextTokenCollater
    text = x
    x = self.ar_text_embedding(text)
    x = self.ar_text_prenet(x)
    x = self.ar_text_position(x)

    text_len = x_lens.max()
    prompts = y
    prefix_len = y.shape[1]

    # AR Decoder
    # TODO: Managing decoder steps avoid repetitive computation
    y = prompts[..., 0]
    if self.ar_audio_prepend_bos:
        y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

    x_len = x_lens.max()
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

    while True:
        # Each step re-runs the decoder over the whole sequence so far.
        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

        y_len = y.shape[1]
        # Text rows cannot attend to audio columns.
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),
            value=True,
        )
        # Audio rows attend to all text and causally to audio.
        y_attn_mask = F.pad(
            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
            y.device
        )

        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        # Only the last position predicts the next token.
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = topk_sampling(
            logits,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_aware_sampling=ras,
            preceding_tokens=y,
        )

        # Stop on EOS (greedy or sampled) or when the generated part exceeds
        # 16 tokens per text token.
        if (
            torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
            or samples[0, 0] == NUM_AUDIO_TOKENS
            or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
        ):
            if prompts.shape[1] == y.shape[1]:
                # NOTE(review): SyntaxError is an unusual type for a runtime
                # failure; callers may depend on it.
                raise SyntaxError("well trained model shouldn't reach here.")

            break

        y = torch.concat([y, samples], dim=1)

    # Generated first-quantizer codes, without the prompt (and BOS if any).
    codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
    if self.num_quantizers == 1:
        return torch.stack(codes, dim=-1)

    # Non-AR Decoders
    y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])

    if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
        enrolled_len = enroll_x_lens.max().item()
        # SOS + Synthesis Text + EOS
        text = torch.concat(
            [
                text[:, :1],
                text[:, enrolled_len - 1 :],
            ],
            dim=1,
        )
        text_len = text_len - (enrolled_len - 2)
        assert text.shape[0] == 1

    x = self.nar_text_embedding(text)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    if self.prefix_mode == 0:
        # Prompt and generated frames are enriched one quantizer at a time.
        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < self.num_quantizers - 2:
                y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                y_emb[:, prefix_len:] += embedding_layer(samples)
    else:
        # The prompt frames receive all quantizer embeddings up front.
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < self.num_quantizers - 2:
                y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    # One entry per quantizer, stacked along a new last dimension.
    return torch.stack(codes, dim=-1)
def visualize(
    self,
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    tokenizer: TextTokenCollater,
    output_dir: str,
    limit: int = 4,
) -> None:
    """Save one three-panel PNG per utterance for the first `limit` items.

    The panels show, from top to bottom, `predicts[0]` ("Encoder Output"),
    `predicts[1]` ("Decoder Output"; its last element if it is a list) and
    `batch["features"]` ("Decoder Target"). A red vertical line marks the
    token / frame count of each utterance. Files are written to
    `{output_dir}/{utt_id}.png`.
    """
    audio_features = batch["features"].to("cpu").detach().numpy()
    audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()

    text_tokens, text_tokens_lens = tokenizer(batch["tokens"])
    assert text_tokens.ndim == 2

    texts = batch["text"]
    utt_ids = [cut.id for cut in batch["cut"]]

    encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        decoder_outputs = decoder_outputs[-1]
    decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()

    code_range = {"vmin": 0, "vmax": 1024}  # Encodec
    num_figures = 3

    def draw_panel(matrix, boundary, label, **imshow_kwargs):
        # Render one heat map in the current subplot with its length marker.
        plt.imshow(
            X=np.transpose(matrix),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            **imshow_kwargs,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=boundary - 0.4, linewidth=2, color="r")
        plt.xlabel(label)
        plt.colorbar()

    selected = zip(utt_ids[:limit], texts[:limit])
    for index, (utt_id, text) in enumerate(selected):
        plt.figure(figsize=(14, 8 * num_figures))

        num_text_tokens = text_tokens_lens[index]
        num_frames = audio_features_lens[index]

        # encoder
        plt.subplot(num_figures, 1, 1)
        plt.title(f"Text: {text}")
        draw_panel(encoder_outputs[index], num_text_tokens, "Encoder Output")

        # decoder
        plt.subplot(num_figures, 1, 2)
        draw_panel(
            decoder_outputs[index], num_frames, "Decoder Output", **code_range
        )

        # target
        plt.subplot(num_figures, 1, 3)
        draw_panel(
            audio_features[index], num_frames, "Decoder Target", **code_range
        )

        plt.savefig(f"{output_dir}/{utt_id}.png")
        plt.close()
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The vocabulary is the LAST dimension of `logits`; any number of leading
    dimensions (including none) is accepted.

    Args:
        logits: logits distribution, shape (..., vocabulary size). The tensor
            is MODIFIED IN PLACE; pass a clone to keep the original.
        top_k: if > 0, keep only the top k tokens with highest probability
            (top-k filtering). Raised to at least `min_tokens_to_keep` and
            capped at the vocabulary size.
        top_p: if < 1.0, keep the smallest set of top tokens whose cumulative
            probability reaches top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al.
            (http://arxiv.org/abs/1904.09751)
        filter_value: value written over removed logits.
        min_tokens_to_keep: lower bound on the number of tokens that survive.

    Returns:
        The same `logits` tensor, with removed entries set to `filter_value`.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect the first min_tokens_to_keep sorted tokens; together with
            # the right shift below, the first min_tokens_to_keep + 1 survive.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original vocabulary order. The
        # last dim is used (not a hard-coded 1) so that 1-D and N-D logits
        # work the same way as the sort above.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
def topk_sampling(
    logits,
    top_k=10,
    top_p=1.0,
    temperature=1.0,
    repetition_aware_sampling=False,
    preceding_tokens=None,
):
    """Sample one token per row from temperature-scaled, filtered logits.

    Args:
        logits: per-row logits; rows are indexed along the first dimension.
        top_k: forwarded to `top_k_top_p_filtering`.
        top_p: forwarded to `top_k_top_p_filtering`.
        temperature: logits are divided by this value when it is not 1.0.
        repetition_aware_sampling: when True, a sampled token that already
            fills more than 10% of a 10-slot window over the most recent
            `preceding_tokens` is replaced by a draw from the unfiltered
            (temperature-scaled) distribution.
            See https://arxiv.org/abs/2406.05370
        preceding_tokens: previously generated tokens, one row per logits
            row (B, T); required when `repetition_aware_sampling` is True.

    Returns:
        The sampled token ids, one column per row (B, 1).
    """
    scaled = logits if temperature == 1.0 else logits / temperature

    # Top-p/top-k filtering works on a copy so `scaled` stays unfiltered.
    filtered = top_k_top_p_filtering(
        scaled.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )
    tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)

    if not repetition_aware_sampling:
        return tokens

    window_size = 10
    threshold = 0.1
    assert preceding_tokens is not None

    history = preceding_tokens
    if history.shape[1] > window_size:
        history = history[:, -window_size:]
    if history.shape[1] == 0:
        return tokens

    for row, recent in enumerate(history):
        # The ratio is always taken over window_size, even if fewer tokens
        # are available.
        repeat_ratio = (recent == tokens[row]).sum() / window_size
        if repeat_ratio > threshold:
            # Resample from the full distribution for this row.
            full_probs = F.softmax(scaled[row], dim=-1)
            tokens[row] = torch.multinomial(full_probs, num_samples=1)
    return tokens