import torch
from torch import nn
from torch import Tensor
from typing import Tuple


class JointCodebookPredictor(nn.Module):
    """
    This module predicts a group of codebook indexes from a vector.  The idea is that
    you have a number of codebooks (probably jointly trained), from class Quantizer,
    and you want to predict the probabilities of the codebook entries based on some
    predictor that you are training.

    The simplest approach would be to project the vector using nn.Linear, then
    reshape and use log_softmax to normalize the probabilities within each group,
    then compute the likelihood.  However, that predicts all the codebooks
    independently of each other.  This module allows you to predict them jointly,
    by regressing each codebook on all previous codebooks.  This is done with a
    nonlinearity in which the previous codebook entries are combined with the
    input predictor vector, so that the regression is not purely linear.

    (A small usage sketch appears at the bottom of this file.)

    Args:
      predictor_dim: the number of features that we use to predict the codebook
          indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting; will likely
          be the same as the bytes_per_frame given to the QuantizerTrainer that
          you used to train the Quantizer you are predicting.
      codebook_size: number of entries per codebook (often 256)
      self_prediction: set this to False to disable prediction of each codebook
          from the earlier-numbered codebooks (i.e. to predict all codebooks
          independently, given only the predictor vector).
      hidden_dim: the hidden dimension per codebook (we use a 1-hidden-layer
          network, with a ReLU and then batchnorm).
    """
    def __init__(
        self,
        predictor_dim: int,
        num_codebooks: int,
        codebook_size: int = 256,
        self_prediction: bool = True,
        hidden_dim: int = 384,
    ):
        super(JointCodebookPredictor, self).__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.hidden_dim = hidden_dim

        self.linear1 = nn.Linear(predictor_dim, num_codebooks * hidden_dim)

        if self_prediction:
            linear_self_out_dim = (num_codebooks - 1) * hidden_dim
            linear_self_in_dim = (num_codebooks - 1) * codebook_size
            self.linear_self = nn.Parameter(
                torch.randn(linear_self_out_dim, linear_self_in_dim)
                * (linear_self_in_dim ** -0.5)
            )
            # This mask zeros out the blocks of linear_self that would let a
            # codebook be predicted from itself or from later-numbered codebooks.
            # E.g. if num_codebooks == 3, hidden_dim == 2 and codebook_size == 2,
            # the expression below has the value:
            # tensor([[ True,  True, False, False],
            #         [ True,  True, False, False],
            #         [ True,  True,  True,  True],
            #         [ True,  True,  True,  True]])
            self.register_buffer(
                "linear_self_mask",
                (
                    (torch.arange(linear_self_out_dim) // hidden_dim).unsqueeze(1)
                    >= (torch.arange(linear_self_in_dim) // codebook_size).unsqueeze(0)
                ),
            )
        else:
            self.register_parameter("linear_self", None)
            self.register_buffer("linear_self_mask", None)

        self.norm = nn.BatchNorm1d(num_codebooks * hidden_dim)
        self.linear2 = nn.Parameter(
            torch.randn(num_codebooks, codebook_size, hidden_dim)
            * (hidden_dim ** -0.5)
        )
        self.bias2 = nn.Parameter(torch.zeros(num_codebooks, codebook_size))

    def forward(
        self, predictor: Tensor, codebook_indexes: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward function.

        Args:
          predictor: a Tensor of some real type, with shape (*, predictor_dim).
          codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
              where the '*' should be the same as for `predictor`.  It will be
              converted to type torch.int64.  Should contain indexes of codebook
              entries, in {0..codebook_size-1}, or negative values which will be
              interpreted as "no codebook index here" (e.g. due to padding); we
              assume that each frame will have either all-negative or
              all-nonnegative indexes, meaning that (codebook_indexes >= 0)
              should not vary as you change the last index into it.

        Returns: (total_logprob, total_count), where:
          total_logprob: a scalar Tensor containing the total log-probability of
              all the nonnegative codebook indexes
          total_count: a scalar Tensor containing the total count of frames with
              nonnegative indexes, satisfying
              total_count <= codebook_indexes.numel() / num_codebooks
        """
        codebook_indexes = codebook_indexes.to(torch.int64)
        assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
        assert codebook_indexes.shape[-1] == self.num_codebooks

        tot_codebook_dim = self.num_codebooks * self.codebook_size
        common_shape = list(predictor.shape[:-1])

        # One-hot encode the codebook indexes; frames with negative (padding)
        # indexes get all-zero rows.
        codebook_one_hot = torch.zeros(
            *common_shape,
            self.num_codebooks,
            self.codebook_size,
            device=predictor.device
        )
        codebook_mask = (codebook_indexes >= 0).unsqueeze(-1)  # (*, num_codebooks, 1)
        codebook_indexes_floor = torch.clamp(codebook_indexes, min=0).unsqueeze(
            -1
        )  # (*, num_codebooks, 1)
        codebook_one_hot.scatter_(
            dim=-1,
            index=codebook_indexes_floor,
            src=codebook_mask.to(torch.float32),
        )
        codebook_one_hot = codebook_one_hot.reshape(
            *common_shape, tot_codebook_dim
        )  # (*, tot_codebook_dim)

        hidden = self.linear1(predictor)
        if self.linear_self is not None:
            # All codebooks except the last serve as inputs to the self-prediction;
            # the mask ensures each codebook only sees earlier-numbered codebooks.
            codebook_one_hot_part = torch.narrow(
                codebook_one_hot, -1, 0, tot_codebook_dim - self.codebook_size
            )
            self_predictor = torch.matmul(
                codebook_one_hot_part,
                (self.linear_self * self.linear_self_mask).transpose(0, 1),
            )
            # add the 'self_predictor' term to all but the 1st block of "hidden".
            hidden_part = torch.narrow(
                hidden,
                -1,
                self.hidden_dim,
                self.hidden_dim * (self.num_codebooks - 1),
            )
            hidden_part += self_predictor

        hidden = nn.functional.relu(hidden)
        hidden = hidden.reshape(-1, self.hidden_dim * self.num_codebooks)
        hidden = self.norm(hidden)
        hidden = hidden.reshape(
            *common_shape, self.num_codebooks, self.hidden_dim
        )  # (*, num_codebooks, hidden_dim)

        # Per-codebook projection to logits, then normalize within each codebook.
        logprobs = (
            torch.matmul(
                hidden.unsqueeze(-2), self.linear2.transpose(1, 2)
            ).squeeze(-2)
            + self.bias2
        )
        # logprobs: (*, num_codebooks, codebook_size)
        logprobs = logprobs.log_softmax(dim=-1)
        logprobs = logprobs.reshape(
            *common_shape, self.num_codebooks * self.codebook_size
        )

        # Sum the log-probs of the observed entries; padding frames contribute
        # zero because their one-hot rows are all zero.
        tot_logprob = torch.dot(logprobs.reshape(-1), codebook_one_hot.reshape(-1))
        assert tot_logprob <= 0.0
        # The select() picks the mask for just one of the codebooks (they should
        # all be the same), since we want the total number of frames.
        tot_count = codebook_mask.select(dim=-2, index=0).sum()
        return (tot_logprob, tot_count)
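

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module): random tensors
# stand in for real predictor features and for codebook indexes produced by a
# trained Quantizer.  The dimensions and the padding convention below are
# illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    predictor_dim, num_codebooks, codebook_size = 512, 4, 256

    model = JointCodebookPredictor(
        predictor_dim=predictor_dim,
        num_codebooks=num_codebooks,
        codebook_size=codebook_size,
    )
    model.eval()  # use BatchNorm's running stats rather than batch statistics

    batch_size, num_frames = 3, 10
    predictor = torch.randn(batch_size, num_frames, predictor_dim)
    codebook_indexes = torch.randint(
        codebook_size, (batch_size, num_frames, num_codebooks)
    )
    # Mark the last two frames of each sequence as padding: all-negative indexes.
    codebook_indexes[:, -2:, :] = -1

    tot_logprob, tot_count = model(predictor, codebook_indexes)
    print("total log-prob:", tot_logprob.item())
    print("non-padding frames:", int(tot_count.item()))
    print("average log-prob per frame:", (tot_logprob / tot_count).item())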