# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import numbers
import random
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tokenizer import TextTokenCollater
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torchmetrics.classification import MulticlassAccuracy

from icefall.utils import make_pad_mask

NUM_TEXT_TOKENS = 5000
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins


class PromptedFeatures:
    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        return PromptedFeatures(self.prompts.to(device), self.features.to(device))

    def sum(self):
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        X = self.word_embeddings(x)
        X = self.dropout(X)

        return X


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)


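# A minimal usage sketch for SinePositionalEmbedding (the shapes and sizes below
# are illustrative assumptions, not taken from a particular recipe):
#
#     >>> pos_emb = SinePositionalEmbedding(dim_model=256, dropout=0.0, alpha=True)
#     >>> x = torch.randn(2, 100, 256)   # (batch, time, dim_model)
#     >>> y = pos_emb(x)                 # x * x_scale + alpha * pe[:, :100]
#     >>> y.shape
#     torch.Size([2, 100, 256])
#
# The cached table follows the usual sinusoidal formulation,
#     pe[t, 2i]     = sin(t / 10000 ** (2i / dim_model))
#     pe[t, 2i + 1] = cos(t / 10000 ** (2i / dim_model))
# and ``extend_pe`` lazily regrows it whenever a longer sequence arrives.

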
class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


_shape_t = Union[int, List[int], torch.Size]


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and byte masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights


class LayerNorm(nn.Module):
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return (
                F.layer_norm(
                    input,
                    self.normalized_shape,
                    self.weight,
                    self.bias,
                    self.eps,
                ),
                embedding,
            )

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


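# AdaptiveLayerNorm conditions a plain normalization layer on a stage embedding:
# the embedding is projected to 2 * d_model values that act as a per-stage scale
# and shift around the normalized activations. A minimal sketch (the dimensions
# below are illustrative assumptions):
#
#     >>> ada_ln = AdaptiveLayerNorm(d_model=256, norm=nn.LayerNorm(256))
#     >>> x = torch.randn(2, 100, 256)
#     >>> stage = torch.randn(1, 256)    # e.g. one nar_stage_embeddings weight below
#     >>> out = ada_ln(x, stage)         # weight * norm(x) + bias

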
class TransformerEncoderLayer(nn.Module):
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        # elif activation == BalancedDoubleSwish:
        #     activation = BalancedDoubleSwish(d_model)

        # # We can't test self.activation in forward() in TorchScript,
        # # so stash some information about it instead.
        # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
        #     self.activation_relu_or_gelu = 1
        # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
        #     self.activation_relu_or_gelu = 2
        # else:
        #     self.activation_relu_or_gelu = 0
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        # if layer_norm_cls == IdentityNorm:
        #     norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        # else:
        if True:
            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


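# With ``norm_first=True`` the layer above is a pre-norm Transformer block (the
# variant used by the VALL-E decoders below): each residual branch reads a
# normalized input, i.e. x = x + self_attn(norm1(x)) followed by
# x = x + feed_forward(norm2(x)); ``norm_first=False`` gives the classic
# post-norm ordering. A small construction sketch (hyper-parameters here are
# illustrative assumptions):
#
#     >>> layer = TransformerEncoderLayer(
#     ...     d_model=256, nhead=4, batch_first=True, norm_first=True
#     ... )
#     >>> out = layer(torch.randn(2, 50, 256))

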
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: return layers' state (optional).

        Shape:
            see the docs in Transformer class.
        """
        if return_layer_states:
            layer_states = []  # layers' output
            output = src
            for mod in self.layers:
                output = mod(
                    output,
                    src_mask=mask,
                    src_key_padding_mask=src_key_padding_mask,
                )
                layer_states.append(output[0])

            if self.norm is not None:
                output = self.norm(output)

            return layer_states, output

        output = src
        for mod in self.layers:
            output = mod(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class VALLE(nn.Module):
    """It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        decoder_cls=TransformerEncoder,
        decoder_layer_cls=TransformerEncoderLayer,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        prepend_bos: bool = False,
        num_quantizers: int = 8,
        **kwargs,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
        """
        super().__init__()
        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
        self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

        # ID NUM_AUDIO_TOKENS     -> PAD
        # ID NUM_AUDIO_TOKENS + 1 -> BOS
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
        )

        # PreNet
        if add_prenet:
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, d_model),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)

        self.ar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=NUM_AUDIO_TOKENS,
        )

        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
                + [
                    TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                    for i in range(num_quantizers - 1)
                ]
            )  # W_a

            # PreNet
            if add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_d_model, nar_d_model),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_d_model, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_d_model),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_d_model,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
                if norm_first
                else None,
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                    for i in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
            )

            if share_embedding:
                # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
                # NOTE(Feiteng): In the experiment, this undermines accuracy
                # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

                # We also share the parameters of the acoustic embedding layer and the output prediction layer,
                # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                        j + 2
                    ].weight

            self.nar_accuracy_metric = MulticlassAccuracy(
                NUM_AUDIO_TOKENS + 1,
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=NUM_AUDIO_TOKENS,
            )

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # inputs, targets
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]

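    # A worked example of ``pad_y_eos`` (token values are illustrative): with
    # y = [[5, 7]], y_mask_int = [[0, 0]] and eos_id = NUM_AUDIO_TOKENS, the
    # padded targets become [[5, 7, 1024]]. Without prepend_bos the method
    # returns inputs [[5, 7]] and targets [[7, 1024]] (a shift by one); with
    # prepend_bos it returns inputs [[1025, 5, 7]] and targets [[5, 7, 1024]],
    # where 1025 = NUM_AUDIO_TOKENS + 1 is the BOS id.
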
    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
        # from the same utterance.
        # We implement this differently.
        if self.prefix_mode == 0:
            # no prefix
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                # Formula (4) (5)
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif self.prefix_mode == 1:
            # prefix at beginning
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif self.prefix_mode in [2, 4]:
            if self.prefix_mode == 2:
                # random prefix
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        else:
            raise ValueError

        return y_emb, prefix_len

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, 8).
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `y`
            before padding.
          train_stage:
            0: AR & NAR modules, 1: AR modules, 2: NAR modules
        Returns:
          Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        y_prompts_codes = None
        if isinstance(y, PromptedFeatures):
            y_prompts_codes, y = y.data
            prompts_len, y_lens = y_lens.data
            assert prompts_len.min() == prompts_len.max()
            assert self.prefix_mode == 4
            y_prompts_codes = y_prompts_codes.type(torch.int64)

        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        # NOTE: x has been padded in TextTokenCollater
        x_mask = make_pad_mask(x_lens).to(x.device)
        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_int = y_mask.type(torch.int64)

        text = x
        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

        y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS)

        x_len = x_lens.max()

        metrics = {}
        total_loss = 0.0

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        if self.ar_audio_prepend_bos:
            ar_xy_padding_mask = torch.concat(
                [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
            )
        else:
            ar_xy_padding_mask = xy_padding_mask
        # AR Decoder
        if train_stage in [0, 1]:
            x = self.ar_text_embedding(text)
            x = self.ar_text_prenet(x)
            x = self.ar_text_position(x)

            y_len = y_lens.max() + int(self.ar_audio_prepend_bos)

            x_attn_mask = F.pad(
                torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                    diagonal=1,
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

            # merge key padding and attention masks
            bsz, src_len = x.shape[0], x_len + y_len
            _xy_padding_mask = (
                ar_xy_padding_mask.view(bsz, 1, 1, src_len)
                .expand(-1, self.num_heads, -1, -1)
                .reshape(bsz * self.num_heads, 1, src_len)
            )
            xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

            new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
            new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
            xy_attn_mask = new_attn_mask

            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)

            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
                # src_key_padding_mask=xy_padding_mask,
                # is_causal=True,
            )
            logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
            # loss
            total_loss = F.cross_entropy(logits, targets, reduction=reduction)

            metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
                logits.detach(), targets
            ).item() * y_lens.sum().type(torch.float32)

        if self.num_quantizers == 1:
            return ((x, codes), total_loss, metrics)

        # Non-AR Decoders
        if self.ar_audio_prepend_bos:
            y = y[:, 1:]
        if train_stage in [0, 2]:
            num_nar_layers = self.num_quantizers - 1
            nar_stage = self.rng.choices(
                [_k for _k in range(1, self.num_quantizers)],
                weights=[1.0 / num_nar_layers] * num_nar_layers,
                k=1,
            )[0]

            x = self.nar_text_embedding(text)
            x = self.nar_text_prenet(x)
            x = self.nar_text_position(x)

            y_emb, prefix_len = self._prepare_prompts(
                y, y_lens, codes, nar_stage, y_prompts_codes
            )

            y_len = y_lens.max()
            targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
            if self.prefix_mode in [2, 4]:
                xy_padding_mask = torch.concat(
                    [
                        x_mask,
                        F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
                    ],
                    dim=1,
                )
            elif self.prefix_mode == 1:
                targets = targets[:, prefix_len:]

            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)
            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
                src_key_padding_mask=xy_padding_mask,
                # is_causal=False,
            )
            xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
            if self.prefix_mode == 4:
                prefix_len = 0  # reset for Top10Accuracy metric
            logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)

            # loss
            total_length = (y_lens).sum().type(torch.float32)
            total_loss += F.cross_entropy(
                logits,
                targets,
                ignore_index=NUM_AUDIO_TOKENS,
                reduction=reduction,
            ) * (total_length / (total_length - prefix_len * x.shape[0]))
            metrics["NarTop10Accuracy"] = (
                self.nar_accuracy_metric(
                    F.pad(
                        logits.detach(),
                        (0, 0, 0, 1, 0, 0),
                        value=logits.min().cpu().item(),
                    ),
                    targets,
                ).item()
                * total_length
            )

        if train_stage == 0:
            total_loss = total_loss / 2.0

        return ((x, codes), total_loss, metrics)

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        top_p: float = 1.0,
        ras: bool = False,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
          top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering. Default to -100.
          temperature: (`optional`) float
            The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
          ras: (`optional`) bool
            Whether to use repetition-aware sampling. Default to False.
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                repetition_aware_sampling=ras,
                preceding_tokens=y,
            )

            if (
                torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
                or samples[0, 0] == NUM_AUDIO_TOKENS
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    raise SyntaxError("well trained model shouldn't reach here.")
                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
        y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        tokenizer: TextTokenCollater,
        output_dir: str,
        limit: int = 4,
    ) -> None:
        audio_features = batch["features"].to("cpu").detach().numpy()
        audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()

        tokens = batch["tokens"]
        text_tokens, text_tokens_lens = tokenizer(tokens)
        assert text_tokens.ndim == 2

        texts = batch["text"]
        utt_ids = [cut.id for cut in batch["cut"]]

        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
        decoder_outputs = predicts[1]
        if isinstance(decoder_outputs, list):
            decoder_outputs = decoder_outputs[-1]
        decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()

        vmin, vmax = 0, 1024  # Encodec
        num_figures = 3
        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
            _ = plt.figure(figsize=(14, 8 * num_figures))

            S = text_tokens_lens[b]
            T = audio_features_lens[b]

            # encoder
            plt.subplot(num_figures, 1, 1)
            plt.title(f"Text: {text}")
            plt.imshow(
                X=np.transpose(encoder_outputs[b]),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
            )
            plt.gca().invert_yaxis()
            plt.axvline(x=S - 0.4, linewidth=2, color="r")
            plt.xlabel("Encoder Output")
            plt.colorbar()

            # decoder
            plt.subplot(num_figures, 1, 2)
            plt.imshow(
                X=np.transpose(decoder_outputs[b]),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
                vmin=vmin,
                vmax=vmax,
            )
            plt.gca().invert_yaxis()
            plt.axvline(x=T - 0.4, linewidth=2, color="r")
            plt.xlabel("Decoder Output")
            plt.colorbar()

            # target
            plt.subplot(num_figures, 1, 3)
            plt.imshow(
                X=np.transpose(audio_features[b]),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
                vmin=vmin,
                vmax=vmax,
            )
            plt.gca().invert_yaxis()
            plt.axvline(x=T - 0.4, linewidth=2, color="r")
            plt.xlabel("Decoder Target")
            plt.colorbar()

            plt.savefig(f"{output_dir}/{utt_id}.png")
            plt.close()


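# A minimal end-to-end sketch of how the model above is typically driven. The
# hyper-parameters and tensor shapes here are illustrative assumptions, not a
# recommended configuration:
#
#     >>> model = VALLE(d_model=1024, nhead=16, num_layers=12, num_quantizers=8)
#     >>> x = torch.randint(0, NUM_TEXT_TOKENS, (4, 50))       # text tokens (N, S)
#     >>> x_lens = torch.full((4,), 50)
#     >>> y = torch.randint(0, NUM_AUDIO_TOKENS, (4, 300, 8))  # EnCodec codes (N, T, 8)
#     >>> y_lens = torch.full((4,), 300)
#     >>> _, loss, metrics = model(x, x_lens, y, y_lens)       # AR + NAR training losses
#
#     >>> # zero-shot synthesis: text plus an acoustic prompt from the target speaker
#     >>> codes = model.inference(
#     ...     x[:1], x_lens[:1], y[:1], enroll_x_lens=x_lens[:1], top_k=10
#     ... )

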
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits


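# A small worked example of the filter above (values chosen for illustration):
# with logits [[2.0, 1.0, 0.5, -1.0]] and top_k=2, every entry smaller than the
# second-largest logit is pushed to -inf, so only the first two tokens remain
# candidates for sampling:
#
#     >>> logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#     >>> top_k_top_p_filtering(logits, top_k=2)
#     tensor([[2., 1., -inf, -inf]])

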
def topk_sampling(
    logits,
    top_k=10,
    top_p=1.0,
    temperature=1.0,
    repetition_aware_sampling=False,
    preceding_tokens=None,
):
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits_filtered = top_k_top_p_filtering(
        logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )
    # Sample
    probs = F.softmax(logits_filtered, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1)

    if repetition_aware_sampling:
        window_size = 10
        threshold = 0.1
        # Repetition-aware sampling (https://arxiv.org/abs/2406.05370): we first
        # generate the target code c_t' by nucleus sampling with a pre-defined
        # top-p value v. Then we calculate the repetition ratio r of token c_t'
        # in the preceding code sequence within a window of size K. If the ratio
        # r exceeds a pre-defined repetition threshold t_n, we replace the target
        # code c_t' by random sampling from p(c_t' | x, c_<t,0; theta_AR), so the
        # token does not keep repeating.
        # preceding_tokens: (B, T); tokens: (B, 1)
        assert preceding_tokens is not None
        if preceding_tokens.shape[1] > window_size:
            preceding_tokens = preceding_tokens[:, -window_size:]
        if preceding_tokens.shape[1] > 0:
            for i, item in enumerate(preceding_tokens):
                # check if the repeat ratio exceeds the threshold
                if (item == tokens[i]).sum() / window_size > threshold:
                    # replace the target code c_t' by random sampling
                    probs = F.softmax(logits[i], dim=-1)
                    token_new = torch.multinomial(probs, num_samples=1)
                    tokens[i] = token_new
    return tokens