Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

commit 8985440ce1 (parent 3570cb738a)
copy quantization files from Dan's repo

.github/workflows/ignore.list (vendored, 7 changed lines)
@@ -1,2 +1,5 @@
-egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py
-egs/librispeech/ASR/conformer_ctc/test_subsampling.py
+egs/librispeech/ASR/conformer_ctc/checkpoint.py
+egs/librispeech/ASR/conformer_ctc/ckpnt_prediction.py
+egs/librispeech/ASR/conformer_ctc/powerful_prediction.py
+egs/librispeech/ASR/conformer_ctc/prediction.py
+egs/librispeech/ASR/conformer_ctc/quantization.py
egs/librispeech/ASR/conformer_ctc/checkpoint.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple, Callable


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, function: Callable, *args):
        # `function` must return either a Tensor or a tuple of Tensors
        ctx.function = function
        ctx.args = [x.detach() if isinstance(x, Tensor) else x for x in args]
        for i in range(len(ctx.args)):
            if isinstance(args[i], Tensor) and args[i].requires_grad:
                ctx.args[i].requires_grad = True
        with torch.no_grad():
            ans = function(*args)

        return ans

    @staticmethod
    def backward(ctx, *ans_grads):
        if not any([a is not None for a in ans_grads]):
            # backward() must return one gradient per forward() input,
            # i.e. one for `function` plus one for each element of `args`.
            return tuple([None] * (len(ctx.args) + 1))
        with torch.enable_grad():
            ans = ctx.function(*ctx.args)
            if isinstance(ans, Tensor):
                assert len(ans_grads) == 1
                loss = (ans * ans_grads[0]).sum()
            else:
                assert len(ans_grads) == len(ans)
                loss = torch.stack(
                    [
                        (a * g).sum()
                        for a, g in zip(ans, ans_grads)
                        if g is not None
                    ]
                ).sum()

            loss.backward()
            return tuple(
                [None]
                + [a.grad if isinstance(a, Tensor) else None for a in ctx.args]
            )


def checkpoint(function, *args):
    return CheckpointFunction.apply(function, *args)


def _test1():
    x = torch.Tensor([0])
    y = torch.Tensor([1])
    y.requires_grad = True
    l = lambda x, y, trash: torch.stack((x, y))
    ans = checkpoint(l, x, y, None)
    # ans = l(x, y, None)
    print("ans = ", ans)
    ans.sum().backward()
    print("y grad = ", y.grad)


def _test2():
    x = torch.Tensor([0])
    y = torch.Tensor([1])
    x.requires_grad = True
    l = lambda x, y, trash: torch.stack((x, y))
    ans = checkpoint(l, x, y, None)
    ans = checkpoint(torch.sum, ans)
    # ans = l(x, y, None)
    print("ans = ", ans)
    ans.backward()
    print("x grad = ", x.grad)


if __name__ == "__main__":
    _test1()
    _test2()
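[Editor's note] For orientation, here is a minimal usage sketch of the `checkpoint` helper above (not part of the commit; the `ff` module and sizes are made up). The wrapped function runs under `no_grad` in the forward pass, so its intermediate activations are freed and recomputed from the detached inputs when `backward()` re-runs it:

    import torch
    from torch import nn
    from checkpoint import checkpoint

    # A hypothetical submodule whose activations we don't want to keep.
    ff = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))

    x = torch.randn(8, 16, requires_grad=True)
    # The intermediate (8, 64) activation is not stored; it is recomputed
    # when backward() calls the function again under enable_grad().
    y = checkpoint(ff, x)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([8, 16]); ff's parameters also get grads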
egs/librispeech/ASR/conformer_ctc/ckpnt_prediction.py (new file, 211 lines)
@@ -0,0 +1,211 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple, Optional
from checkpoint import (
    checkpoint,
)  # from current directory.. could not get relative import to work..


# Functional version of the joint codebook loss, added so that we can more
# easily implement checkpointing to save memory.
def joint_codebook_loss(
    predictor: Tensor,
    codebook_indexes: Tensor,
    linear1_weight: Tensor,
    linear1_bias: Optional[Tensor],
    codebook_embedding_weight: Tensor,
    linear2_weight: Tensor,
    linear2_bias: Tensor,
    ignore_index: int,
    reduction: str,
) -> Tensor:
    """
    Args:
      predictor: predictor tensor of shape (*, predictor_channels)
      codebook_indexes: codebook indexes of shape (*, num_codebooks)
      linear1_weight: weight of shape (hidden_channels, predictor_channels)
      linear1_bias: optional bias of shape (hidden_channels,)
      codebook_embedding_weight: weight of shape
         ((num_codebooks - 1) * codebook_size, hidden_channels)
      linear2_weight: weight of shape (num_codebooks, codebook_size,
         hidden_channels)
      linear2_bias: bias of shape (num_codebooks, codebook_size)
      ignore_index: index to ignore in cross entropy loss, e.g. -100
      reduction: reduction in cross entropy loss, e.g. 'sum'
    """
    num_codebooks = codebook_indexes.shape[-1]
    predictor_channels = predictor.shape[-1]
    hidden_channels = linear1_weight.shape[0]
    codebook_size = codebook_embedding_weight.shape[0] // (num_codebooks - 1)

    codebook_indexes = codebook_indexes.to(torch.int64)
    assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
    predictor = predictor.reshape(
        -1, predictor.shape[-1]
    )  # (N, predictor_channels)
    codebook_indexes = codebook_indexes.reshape(-1, codebook_indexes.shape[-1])
    first_indexes = codebook_indexes[
        :, :-1
    ]  # all but the last codebook index; (N, num_codebooks-1)

    # do clamp(min=0) to avoid errors on padding (-100).. these frames will
    # later be ignored in the loss, so the value can be treated as a don't-care.
    first_indexes = first_indexes.clamp(min=0) + torch.arange(
        0,
        (num_codebooks - 1) * codebook_size,
        step=codebook_size,
        device=first_indexes.device,
    )  # (N, num_codebooks-1)

    first_embeddings = torch.nn.functional.embedding(
        first_indexes, codebook_embedding_weight
    ) * (
        hidden_channels ** 0.5
    )  # (N, num_codebooks-1, hidden_channels)

    hidden_predictor = torch.nn.functional.linear(
        predictor, linear1_weight, linear1_bias
    )
    all_embeddings = torch.cat(
        (hidden_predictor.unsqueeze(1), first_embeddings), dim=1
    )  # (N, num_codebooks, hidden_channels)

    # After cumsum, all positions will contain a contribution from
    # 'hidden_predictor', and will also contain contributions from all
    # *previous* codebooks.  Here, "position" means a position in
    # {0..num_codebooks-1}.
    all_embeddings = torch.cumsum(
        all_embeddings, dim=1
    )  # (N, num_codebooks, hidden_channels)

    all_embeddings = torch.nn.functional.relu(all_embeddings)

    logprobs = torch.matmul(
        all_embeddings.transpose(0, 1),  # (num_codebooks, N, hidden_channels)
        linear2_weight.transpose(
            1, 2
        ),  # (num_codebooks, hidden_channels, codebook_size)
    ).transpose(
        0, 1
    )  # (N, num_codebooks, codebook_size)
    logprobs += linear2_bias
    logprobs = logprobs.log_softmax(dim=2)  # (N, num_codebooks, codebook_size)

    return torch.nn.functional.cross_entropy(
        logprobs.reshape(-1, codebook_size),
        codebook_indexes.reshape(-1),
        ignore_index=ignore_index,
        reduction=reduction,
    )


class JointCodebookLoss(nn.Module):
    """
    This module predicts a group of codebook indexes from a vector.  The idea
    is that you have a number of codebooks (probably jointly trained), from
    class Quantizer, and you want to predict the probabilities of the codebook
    entries based on some predictor that you are training.

    The simplest thing would be to project the vector using nn.Linear, then
    reshape and use logsoftmax to normalize the probabilities within each
    group, then compute the likelihood.  However, this has the constraint that
    all the codebooks are predicted independently of each other.  This module
    allows you to predict them jointly, by regressing each codebook on all
    previous codebooks.  This is done with a nonlinearity in which the
    previous codebook entries are combined with the input predictor vector, so
    that the regression is not purely linear.

    Args:
      predictor_channels: the number of features that we use to predict the
         codebook indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting;
         will likely be the same as the bytes_per_frame given to the
         QuantizerTrainer that you used to train the Quantizer you
         are predicting.
      hidden_channels: a hidden dimension in the model; should be more than
         codebook_size, but may be less or more than predictor_channels.
      codebook_size: number of entries per codebook (often 256)
      reduction: reduction in cross entropy loss, e.g. 'sum'
      ignore_index: index to ignore in cross entropy loss, e.g. -100
      checkpoint: if true, reduce backprop memory at the expense of doing
         the computation twice.
    """

    def __init__(
        self,
        predictor_channels: int,
        num_codebooks: int,
        hidden_channels: int = 512,
        codebook_size: int = 256,
        reduction: str = "sum",
        ignore_index: int = -100,
        checkpoint: bool = True,
    ):
        super(JointCodebookLoss, self).__init__()

        assert num_codebooks > 1  # we may later handle this specially.
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.hidden_channels = hidden_channels
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.checkpoint = checkpoint

        self.linear1 = nn.Linear(predictor_channels, hidden_channels)

        # codebook_embedding is used to predict each codebook from previous
        # codebooks, so it's a joint, not independent, model.  We'll multiply
        # this by hidden_channels ** 0.5 when we use it; this keeps the
        # magnitude small and allows it to train fast enough (relatively
        # speaking).
        self.codebook_embedding = nn.Embedding(
            (num_codebooks - 1) * codebook_size,
            hidden_channels,
            _weight=torch.randn(
                (num_codebooks - 1) * codebook_size, hidden_channels
            )
            * (hidden_channels ** -0.5),
        )
        self.nonlin = nn.ReLU(inplace=True)

        self.linear2_weight = nn.Parameter(
            torch.randn(num_codebooks, codebook_size, hidden_channels)
            * (hidden_channels ** -0.5)
        )
        self.linear2_bias = nn.Parameter(
            torch.zeros(num_codebooks, codebook_size)
        )

    def forward(
        self, predictor: Tensor, codebook_indexes: Tensor
    ) -> Tensor:
        """
        Forward function.
        Args:
          predictor: a Tensor of some real type, with shape (*, predictor_channels).
          codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
            where the '*' should be the same as for `predictor`.  It will be
            converted to type torch.int64.  Should contain indexes of codebook
            entries, in {0..codebook_size-1}, or negative values which will be
            interpreted as "no codebook index here" (e.g. due to padding); we
            assume that each frame will either have all-negative or
            all-nonnegative indexes, meaning that (codebook_indexes >= 0)
            should not vary as you change the last index into it.
        Returns:
          cross_entropy_loss, which will be a total negated log-probability,
          assuming reduction == 'sum'.
        """

        args = (
            predictor,
            codebook_indexes,
            self.linear1.weight,
            self.linear1.bias,
            self.codebook_embedding.weight,
            self.linear2_weight,
            self.linear2_bias,
            self.ignore_index,
            self.reduction,
        )
        if self.checkpoint:
            return checkpoint(joint_codebook_loss, *args)
        else:
            return joint_codebook_loss(*args)
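[Editor's note] A quick way to exercise `JointCodebookLoss` is the following sketch (not from the commit; all shapes are illustrative). Padded frames are marked with the default `ignore_index` of -100 and contribute nothing to the loss:

    import torch
    from ckpnt_prediction import JointCodebookLoss

    num_codebooks, codebook_size = 4, 256
    module = JointCodebookLoss(
        predictor_channels=512, num_codebooks=num_codebooks
    )
    # (batch, frames, channels)
    predictor = torch.randn(2, 10, 512, requires_grad=True)
    codebook_indexes = torch.randint(0, codebook_size, (2, 10, num_codebooks))
    codebook_indexes[:, -2:, :] = -100  # pretend the last two frames are padding
    loss = module(predictor, codebook_indexes)
    loss.backward()  # with checkpoint=True the forward is recomputed here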
egs/librispeech/ASR/conformer_ctc/powerful_prediction.py (new file, 228 lines)
@@ -0,0 +1,228 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple, Optional
from checkpoint import (
    checkpoint,
)  # from current directory.. could not get relative import to work..


# Functional version of the joint codebook loss, added so that we can more
# easily implement checkpointing to save memory.
def joint_codebook_loss(
    predictor: Tensor,
    codebook_indexes: Tensor,
    linear1_weight: Tensor,
    linear1_bias: Optional[Tensor],
    codebook_embedding_weight: Tensor,
    linear2_weight: Tensor,
    linear2b_weight: Tensor,
    linear2_bias: Tensor,
    ignore_index: int,
    reduction: str,
) -> Tensor:
    """
    Args:
      predictor: predictor tensor of shape (*, predictor_channels)
      codebook_indexes: codebook indexes of shape (*, num_codebooks)
      linear1_weight: weight of shape (hidden_channels, predictor_channels)
      linear1_bias: optional bias of shape (hidden_channels,)
      codebook_embedding_weight: weight of shape
         ((num_codebooks - 1) * codebook_size, hidden_channels)
      linear2_weight: weight of shape (num_codebooks, codebook_size,
         hidden_channels)
      linear2b_weight: weight of shape (num_codebooks, codebook_size,
         predictor_channels)
      linear2_bias: bias of shape (num_codebooks, codebook_size)
      ignore_index: index to ignore in cross entropy loss, e.g. -100
      reduction: reduction in cross entropy loss, e.g. 'sum'
    """
    num_codebooks = codebook_indexes.shape[-1]
    predictor_channels = predictor.shape[-1]
    hidden_channels = linear1_weight.shape[0]
    codebook_size = codebook_embedding_weight.shape[0] // (num_codebooks - 1)

    codebook_indexes = codebook_indexes.to(torch.int64)
    assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
    predictor = predictor.reshape(
        -1, predictor.shape[-1]
    )  # (N, predictor_channels)
    codebook_indexes = codebook_indexes.reshape(-1, codebook_indexes.shape[-1])
    first_indexes = codebook_indexes[
        :, :-1
    ]  # all but the last codebook index; (N, num_codebooks-1)

    # do clamp(min=0) to avoid errors on padding (-100).. these frames will
    # later be ignored in the loss, so the value can be treated as a don't-care.
    first_indexes = first_indexes.clamp(min=0) + torch.arange(
        0,
        (num_codebooks - 1) * codebook_size,
        step=codebook_size,
        device=first_indexes.device,
    )  # (N, num_codebooks-1)

    first_embeddings_scale = 0.5 * ((hidden_channels / num_codebooks) ** 0.5)
    first_embeddings = (
        torch.nn.functional.embedding(first_indexes, codebook_embedding_weight)
        * first_embeddings_scale
    )  # (N, num_codebooks-1, hidden_channels)

    hidden_predictor = torch.nn.functional.linear(
        predictor, linear1_weight, linear1_bias
    )
    all_embeddings = torch.cat(
        (hidden_predictor.unsqueeze(1), first_embeddings), dim=1
    )  # (N, num_codebooks, hidden_channels)

    # After cumsum, all positions will contain a contribution from
    # 'hidden_predictor', and will also contain contributions from all
    # *previous* codebooks.  Here, "position" means a position in
    # {0..num_codebooks-1}.
    all_embeddings = torch.cumsum(
        all_embeddings, dim=1
    )  # (N, num_codebooks, hidden_channels)

    all_embeddings = torch.nn.functional.relu(all_embeddings)

    logprobs = torch.matmul(
        all_embeddings.transpose(0, 1),  # (num_codebooks, N, hidden_channels)
        linear2_weight.transpose(
            1, 2
        ),  # (num_codebooks, hidden_channels, codebook_size)
    ).transpose(
        0, 1
    )  # (N, num_codebooks, codebook_size)

    logprobs += torch.matmul(
        predictor,  # (N, predictor_channels)
        linear2b_weight.transpose(
            1, 2
        ),  # (num_codebooks, predictor_channels, codebook_size)
    ).transpose(
        0, 1
    )  # (N, num_codebooks, codebook_size)

    logprobs += linear2_bias
    logprobs = logprobs.log_softmax(dim=2)  # (N, num_codebooks, codebook_size)

    return torch.nn.functional.cross_entropy(
        logprobs.reshape(-1, codebook_size),
        codebook_indexes.reshape(-1),
        ignore_index=ignore_index,
        reduction=reduction,
    )


class Powerful_JointCodebookLoss(nn.Module):
    """
    This module predicts a group of codebook indexes from a vector.  The idea
    is that you have a number of codebooks (probably jointly trained), from
    class Quantizer, and you want to predict the probabilities of the codebook
    entries based on some predictor that you are training.

    The simplest thing would be to project the vector using nn.Linear, then
    reshape and use logsoftmax to normalize the probabilities within each
    group, then compute the likelihood.  However, this has the constraint that
    all the codebooks are predicted independently of each other.  This module
    allows you to predict them jointly, by regressing each codebook on all
    previous codebooks.  This is done with a nonlinearity in which the
    previous codebook entries are combined with the input predictor vector, so
    that the regression is not purely linear.

    Args:
      predictor_channels: the number of features that we use to predict the
         codebook indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting;
         will likely be the same as the bytes_per_frame given to the
         QuantizerTrainer that you used to train the Quantizer you
         are predicting.
      hidden_channels: a hidden dimension in the model; should be more than
         codebook_size, but may be less or more than predictor_channels.
      codebook_size: number of entries per codebook (often 256)
      reduction: reduction in cross entropy loss, e.g. 'sum'
      ignore_index: index to ignore in cross entropy loss, e.g. -100
      checkpoint: if true, reduce backprop memory at the expense of doing
         the computation twice.
    """

    def __init__(
        self,
        predictor_channels: int,
        num_codebooks: int,
        hidden_channels: int = 512,
        codebook_size: int = 256,
        reduction: str = "sum",
        ignore_index: int = -100,
        checkpoint: bool = True,
    ):
        super(Powerful_JointCodebookLoss, self).__init__()

        assert num_codebooks > 1  # we may later handle this specially.
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.hidden_channels = hidden_channels
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.checkpoint = checkpoint

        self.linear1 = nn.Linear(predictor_channels, hidden_channels)

        # codebook_embedding is used to predict each codebook from previous
        # codebooks, so it's a joint, not independent, model.  We'll scale
        # this up when we use it; this keeps the magnitude small and allows it
        # to train fast enough (relatively speaking).
        self.codebook_embedding = nn.Embedding(
            (num_codebooks - 1) * codebook_size,
            hidden_channels,
            _weight=torch.randn(
                (num_codebooks - 1) * codebook_size, hidden_channels
            )
            * (hidden_channels ** -0.5),
        )

        self.linear2_weight = nn.Parameter(
            torch.randn(num_codebooks, codebook_size, hidden_channels)
            * (hidden_channels ** -0.5)
        )
        self.linear2b_weight = nn.Parameter(
            torch.randn(num_codebooks, codebook_size, predictor_channels)
            * (predictor_channels ** -0.5)
        )
        self.linear2_bias = nn.Parameter(
            torch.zeros(num_codebooks, codebook_size)
        )

    def forward(
        self, predictor: Tensor, codebook_indexes: Tensor
    ) -> Tensor:
        """
        Forward function.
        Args:
          predictor: a Tensor of some real type, with shape (*, predictor_channels).
          codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
            where the '*' should be the same as for `predictor`.  It will be
            converted to type torch.int64.  Should contain indexes of codebook
            entries, in {0..codebook_size-1}, or negative values which will be
            interpreted as "no codebook index here" (e.g. due to padding); we
            assume that each frame will either have all-negative or
            all-nonnegative indexes, meaning that (codebook_indexes >= 0)
            should not vary as you change the last index into it.
        Returns:
          cross_entropy_loss, which will be a total negated log-probability,
          assuming reduction == 'sum'.
        """

        args = (
            predictor,
            codebook_indexes,
            self.linear1.weight,
            self.linear1.bias,
            self.codebook_embedding.weight,
            self.linear2_weight,
            self.linear2b_weight,
            self.linear2_bias,
            self.ignore_index,
            self.reduction,
        )
        if self.checkpoint:
            return checkpoint(joint_codebook_loss, *args)
        else:
            return joint_codebook_loss(*args)
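[Editor's note] Relative to ckpnt_prediction.py, the only functional change here is the extra `linear2b_weight` term, which lets the raw predictor contribute to the output logits directly, bypassing the hidden layer. A standalone sketch of just that term, with made-up shapes:

    import torch

    N, num_codebooks, codebook_size, predictor_channels = 10, 4, 256, 512
    predictor = torch.randn(N, predictor_channels)
    linear2b_weight = torch.randn(num_codebooks, codebook_size, predictor_channels)

    direct = torch.matmul(
        predictor,  # (N, predictor_channels)
        linear2b_weight.transpose(1, 2),  # (num_codebooks, predictor_channels, codebook_size)
    ).transpose(0, 1)  # broadcasts to (num_codebooks, N, ...), then swaps back
    assert direct.shape == (N, num_codebooks, codebook_size)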
egs/librispeech/ASR/conformer_ctc/prediction.py (new file, 187 lines)
@@ -0,0 +1,187 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple


class JointCodebookPredictor(nn.Module):
    """
    This module predicts a group of codebook indexes from a vector.  The idea
    is that you have a number of codebooks (probably jointly trained), from
    class Quantizer, and you want to predict the probabilities of the codebook
    entries based on some predictor that you are training.

    The simplest thing would be to project the vector using nn.Linear, then
    reshape and use logsoftmax to normalize the probabilities within each
    group, then compute the likelihood.  However, this has the constraint that
    all the codebooks are predicted independently of each other.  This module
    allows you to predict them jointly, by regressing each codebook on all
    previous codebooks.  This is done with a nonlinearity in which the
    previous codebook entries are combined with the input predictor vector, so
    that the regression is not purely linear.

    Args:
      predictor_dim: the number of features that we use to predict the codebook
         indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting;
         will likely be the same as the bytes_per_frame given to the
         QuantizerTrainer that you used to train the Quantizer you
         are predicting.
      codebook_size: number of entries per codebook (often 256)
      self_prediction: you can set this to False to disable prediction of
         codebooks by earlier-numbered codebooks
      hidden_dim: the hidden dimension per codebook (we use a 1-hidden-layer
         network, with a ReLU and then batchnorm).
    """

    def __init__(
        self,
        predictor_dim: int,
        num_codebooks: int,
        codebook_size: int = 256,
        self_prediction: bool = True,
        hidden_dim: int = 384,
    ):
        super(JointCodebookPredictor, self).__init__()

        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.hidden_dim = hidden_dim

        self.linear1 = nn.Linear(predictor_dim, num_codebooks * hidden_dim)

        if self_prediction:
            linear_self_out_dim = (num_codebooks - 1) * hidden_dim
            linear_self_in_dim = (num_codebooks - 1) * codebook_size
            self.linear_self = nn.Parameter(
                torch.randn(linear_self_out_dim, linear_self_in_dim)
                * (linear_self_in_dim ** -0.5)
            )

            # If num_codebooks == 3 and hidden_dim == 2 and codebook_size == 2,
            # the expression below has the value:
            # tensor([[ True,  True, False, False],
            #         [ True,  True, False, False],
            #         [ True,  True,  True,  True],
            #         [ True,  True,  True,  True]])
            self.register_buffer(
                "linear_self_mask",
                (
                    (torch.arange(linear_self_out_dim) // hidden_dim).unsqueeze(
                        1
                    )
                    >= (
                        torch.arange(linear_self_in_dim) // codebook_size
                    ).unsqueeze(0)
                ),
            )
        else:
            self.register_parameter("linear_self", None)
            self.register_buffer("linear_self_mask", None)

        self.norm = nn.BatchNorm1d(num_codebooks * hidden_dim)
        self.linear2 = nn.Parameter(
            torch.randn(num_codebooks, codebook_size, hidden_dim)
            * (hidden_dim ** -0.5)
        )
        self.bias2 = nn.Parameter(torch.zeros(num_codebooks, codebook_size))

    def forward(
        self, predictor: Tensor, codebook_indexes: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward function.
        Args:
          predictor: a Tensor of some real type, with shape (*, predictor_dim).
          codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
            where the '*' should be the same as for `predictor`.  It will be
            converted to type torch.int64.  Should contain indexes of codebook
            entries, in {0..codebook_size-1}, or negative values which will be
            interpreted as "no codebook index here" (e.g. due to padding); we
            assume that each frame will either have all-negative or
            all-nonnegative indexes, meaning that (codebook_indexes >= 0)
            should not vary as you change the last index into it.
        Returns: (total_logprob, total_count), where:
          total_logprob: a scalar Tensor containing the total log-probability
            of all the nonnegative codebook indexes,
          total_count: a scalar Tensor containing the total count of nonzero
            frames, satisfying
            total_count <= codebook_indexes.numel() / num_codebooks
        """
        codebook_indexes = codebook_indexes.to(torch.int64)
        # import pdb; pdb.set_trace()
        assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
        assert codebook_indexes.shape[-1] == self.num_codebooks

        tot_codebook_dim = self.num_codebooks * self.codebook_size

        common_shape = list(predictor.shape[:-1])
        codebook_one_hot = torch.zeros(
            *common_shape,
            self.num_codebooks,
            self.codebook_size,
            device=predictor.device
        )

        codebook_mask = (codebook_indexes >= 0).unsqueeze(
            -1
        )  # (*, num_codebooks, 1)
        codebook_indexes_floor = torch.clamp(codebook_indexes, min=0).unsqueeze(
            -1
        )  # (*, num_codebooks, 1)
        codebook_one_hot.scatter_(
            dim=-1,
            index=codebook_indexes_floor,
            src=codebook_mask.to(torch.float32),
        )
        codebook_one_hot = codebook_one_hot.reshape(
            *common_shape, tot_codebook_dim
        )  # (*, tot_codebook_dim)

        hidden = self.linear1(predictor)
        if self.linear_self is not None:
            codebook_one_hot_part = torch.narrow(
                codebook_one_hot, -1, 0, tot_codebook_dim - self.codebook_size
            )
            self_predictor = torch.matmul(
                codebook_one_hot_part,
                (self.linear_self * self.linear_self_mask).transpose(0, 1),
            )

            # add the 'self_predictor' term to all but the 1st
            # block of "hidden".
            hidden_part = torch.narrow(
                hidden,
                -1,
                self.hidden_dim,
                self.hidden_dim * (self.num_codebooks - 1),
            )

            hidden_part += self_predictor

        hidden = nn.functional.relu(hidden)
        hidden = hidden.reshape(-1, self.hidden_dim * self.num_codebooks)
        hidden = self.norm(hidden)
        hidden = hidden.reshape(
            *common_shape, self.num_codebooks, self.hidden_dim
        )  # (*, num_codebooks, hidden_dim)

        logprobs = (
            torch.matmul(
                hidden.unsqueeze(-2), self.linear2.transpose(1, 2)
            ).squeeze(-2)
            + self.bias2
        )

        # logprobs: (*, num_codebooks, codebook_size)
        logprobs = logprobs.log_softmax(dim=-1)
        logprobs = logprobs.reshape(
            *common_shape, self.num_codebooks * self.codebook_size
        )

        tot_logprob = torch.dot(
            logprobs.reshape(-1), codebook_one_hot.reshape(-1)
        )
        assert tot_logprob <= 0.0
        # the select() part is to select only the mask for one of the codebooks
        # (they should all be the same), as we want the total number of frames.
        tot_count = codebook_mask.select(dim=-2, index=0).sum()

        return (tot_logprob, tot_count)
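[Editor's note] Unlike the two loss modules above, `JointCodebookPredictor` returns a `(total_logprob, total_count)` pair rather than a loss, so a training loop would negate and normalize it itself. A usage sketch with made-up shapes (not from the commit):

    import torch
    from prediction import JointCodebookPredictor

    num_codebooks, codebook_size = 4, 256
    module = JointCodebookPredictor(predictor_dim=512, num_codebooks=num_codebooks)
    module.train()

    predictor = torch.randn(6, 512)
    codebook_indexes = torch.randint(0, codebook_size, (6, num_codebooks))
    codebook_indexes[-1, :] = -1  # a padded frame: all indexes negative

    tot_logprob, tot_count = module(predictor, codebook_indexes)
    loss = -tot_logprob / tot_count  # average negated logprob per counted frame
    loss.backward()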
egs/librispeech/ASR/conformer_ctc/quantization.py (new file, 923 lines)
@@ -0,0 +1,923 @@
import binascii
import h5py
import math
import numpy as np
import os
import time
import torch
import random
import logging
from torch import nn
from torch import Tensor
from typing import Tuple


class Quantizer(nn.Module):
    # What this implements appears to be referred to as direct-sum codebooks
    # in the scientific literature.  See also residual vector quantization, or
    # multistage VQ, although this method jointly optimizes the codebook
    # entries so there is no order.
    def __init__(self, dim: int, codebook_size: int, num_codebooks: int):
        """
        Trainable quantizer that encodes a vector into a sequence of integers
        (corresponding to multiple separate codebooks), aiming to get the
        least possible expected squared difference.
        """
        super(Quantizer, self).__init__()

        self.dim = dim
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        def is_power_of_two(n: int) -> bool:
            return (n & (n - 1) == 0) and n != 0

        assert is_power_of_two(codebook_size)
        assert is_power_of_two(num_codebooks)

        self.to_logits = nn.Linear(dim, codebook_size * num_codebooks)
        self.logits_scale = 4

        # self.centers: (num_codebooks, codebook_size, dim)
        self.centers = nn.Parameter(
            self.to_logits.weight.detach()
            .clone()
            .reshape(num_codebooks, codebook_size, dim)
        )

        # We give each Quantizer a unique 8-digit hex identifier, which we'll
        # use to reduce the probability of mixing up the outputs of different
        # Quantizers.  It is saved as a buffer, as well as a string, so that
        # it will be loaded from disk when we use load_state_dict().
        id_bytes = binascii.b2a_hex(
            os.urandom(4)
        )  # random hex string, e.g. b'585ce3cf'
        self.id_str = id_bytes.decode("utf-8")
        self.register_buffer(
            "id_buf", torch.tensor(list(id_bytes), dtype=torch.uint8)
        )

    def load_state_dict(self, *args, **kwargs):
        super(Quantizer, self).load_state_dict(*args, **kwargs)
        self.id_str = bytes(self.id_buf.tolist()).decode("utf-8")

    def get_id(self) -> str:
        return self.id_str

    def show_init_invocation(self) -> str:
        return f"quantization.Quantizer(dim={self.dim}, codebook_size={self.codebook_size}, num_codebooks={self.num_codebooks})"

    def get_data_mean(self) -> Tensor:
        """
        Return an approximate expression for the mean of the training data, as
        a tensor of shape (dim,).  This is useful for diagnostics.  It is
        detached from gradient, to avoid this affecting the optimization.
        The expression we use assumes balanced codebook probabilities, which
        is true in practice (as long as index_entropy_loss in training is
        fairly small).
        """
        return self.centers.mean(dim=1).sum(dim=0).detach()

    def get_product_quantizer(self) -> "Quantizer":
        """
        Returns a Quantizer object with codebook_size = self.codebook_size**2
        and num_codebooks = self.num_codebooks//2, initialized so that each
        codebook in the result is formed from pairs of codebooks in this
        object.
        """
        new_codebook_size = self.codebook_size ** 2
        new_num_codebooks = self.num_codebooks // 2

        ans = Quantizer(self.dim, new_codebook_size, new_num_codebooks).to(
            self.centers.device
        )

        ans.apply_mask = False

        with torch.no_grad():
            for c_out in range(new_num_codebooks):
                c_in1 = 2 * c_out
                c_in2 = 2 * c_out + 1
                for k_in1 in range(self.codebook_size):
                    row_in1 = self.codebook_size * c_in1 + k_in1
                    for k_in2 in range(self.codebook_size):
                        row_in2 = self.codebook_size * c_in2 + k_in2
                        k_out = k_in1 * self.codebook_size + k_in2
                        row_out = new_codebook_size * c_out + k_out
                        ans.to_logits.weight[row_out, :] = (
                            self.to_logits.weight[row_in1]
                            + self.to_logits.weight[row_in2]
                        )
                        ans.to_logits.bias[row_out] = (
                            self.to_logits.bias[row_in1]
                            + self.to_logits.bias[row_in2]
                        )
                        ans.centers[c_out, k_out, :] = (
                            self.centers[c_in1, k_in1, :]
                            + self.centers[c_in2, k_in2, :]
                        )
        return ans

    def decode(self, indexes: Tensor) -> Tensor:
        """
        Does the (approximate) inverse of _compute_indexes(): constructs from
        `indexes` the corresponding approximated tensor.
        Args:
          indexes:
            May be an integer tensor of shape (*, self.num_codebooks), with
            entries in {0..self.codebook_size-1}.
            May also contain multiple codebook entries combined into one
            integer, as done by encode() with as_bytes==True; in this case
            the last dim might be self.num_codebooks/2 or
            self.num_codebooks/4.
        Returns: a tensor of shape (*, self.dim), consisting of the sum of
          the specified cluster centers.
        """
        orig_shape = indexes.shape
        indexes = indexes.reshape(-1, indexes.shape[-1])
        indexes = self._maybe_separate_indexes(indexes).to(dtype=torch.int64)

        assert indexes.ndim == 2
        B = indexes.shape[0]
        # indexes_expanded: (num_codebooks, B, dim)
        indexes_expanded = (
            indexes.transpose(0, 1)
            .contiguous()
            .unsqueeze(-1)
            .expand(self.num_codebooks, B, self.dim)
        )
        # self.centers: (num_codebooks, codebook_size, dim)
        # chosen_codebooks: (num_codebooks, B, dim).
        chosen_codebooks = torch.gather(
            self.centers, dim=1, index=indexes_expanded
        )

        # x_approx: (B, dim); this is the sum of the chosen rows of
        # `to_output` corresponding to the chosen codebook entries; this
        # would correspond to the approximated x.
        x_approx = chosen_codebooks.sum(dim=0)
        return x_approx.reshape(*orig_shape[:-1], self.dim)

    def compute_codebook_correlations(self) -> Tensor:
        """
        Return a Tensor of shape (self.num_codebooks, self.num_codebooks)
        with values >= 0, which are greater if a pair of codebooks more
        strongly shares a subspace.  This is for diagnostic purposes.
        These correlations are computed by:
          - subtracting the mean value from each codebook
          - creating an uncentered variance S_i for each codebook i
          - computing, for each pair of codebooks i and j, c_{ij} = tr(S_i S_j)
          - returning c_{ij} / sqrt(c_{ii} c_{jj}), which is a symmetric
            matrix with values in [0,1]
        """
        centers = self.centers.detach()
        codebook_means = centers.mean(
            dim=1, keepdim=True
        )  # (num_codebooks, 1, dim)
        centers = centers - codebook_means  # make each codebook zero-mean.

        # variances: (num_codebooks, dim, dim)
        variances = torch.matmul(centers.transpose(1, 2), centers)

        # variances_flat: (num_codebooks, dim * dim)
        variances_flat = variances.reshape(
            self.num_codebooks, self.dim * self.dim
        )

        # cross_variances: (num_codebooks, num_codebooks), should be all
        # positive (interpret these as tr(0.5*(V1 V2 + V2 V1)) == tr(V1 V2)
        # == the sum of products of corresponding elements; for this, we use
        # the fact that V1 and V2 are both symmetric).
        cross_variances = torch.matmul(variances_flat, variances_flat.t())

        normalizer = cross_variances.diag() ** -0.5
        normalizer = normalizer.unsqueeze(0) * normalizer.unsqueeze(1)
        return cross_variances * normalizer

    def compute_loss(
        self, x: Tensor, refine_indexes_iters: int = 0
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Compute various parts of the loss function.
        Args:
          x: the Tensor to quantize, of shape (*, dim)
          refine_indexes_iters: a number >= 0: the number of iterations to
            refine the indexes from their initial value.
        Returns: (rel_reconstruction_loss, logprob_loss, logits_entropy_loss,
          index_entropy_loss), where
          rel_reconstruction_loss: a scalar torch.Tensor containing the
            relative sum-squared reconstruction loss, based on the indexes
            chosen after `refine_indexes_iters` iterations of refinement
            after the argmax of the logits.  This loss is the sum-squared of
            (x - reconstructed_x) divided by the sum-squared of (x - x_mean),
            which for already-trained models will be between 0 and 1, but
            could be greater than 1 at the start of training.
          logprob_loss: the negative average logprob of the selected classes
            (i.e. those selected after refine_indexes_iters of refinement).
            This is added to the loss function, so we can select reasonable
            classes before refining the indexes.
          logits_entropy_loss: the class entropy loss, from the logits, which
            approaches zero when all classes of all codebooks are
            equi-probable (in the logits output).
          index_entropy_loss: the class entropy loss, from the computed
            indexes, which approaches zero when all classes of all codebooks
            are equi-probable (in the indexes output).  Not differentiable
            but useful for diagnostics.
        """
        x = x.reshape(-1, self.dim)
        indexes = self._compute_indexes(x, refine_indexes_iters)
        x_approx = self.decode(indexes)
        # tot_error: (B, dim), the error of the approximated vs. real x.
        tot_error = x_approx - x
        rel_reconstruction_loss = (tot_error ** 2).sum() / (
            ((x - self.get_data_mean()) ** 2).sum() + 1.0e-20
        )

        # Get logprob loss and class-entropy loss
        # wasteful.. already computed logits..
        logits = self._logits(x).reshape(
            -1, self.num_codebooks, self.codebook_size
        )
        logits = logits.log_softmax(dim=2)
        # chosen_logits: (B, num_codebooks, 1)
        chosen_logits = torch.gather(logits, dim=2, index=indexes.unsqueeze(2))
        logprob_loss = -chosen_logits.mean()

        # class_entropy
        B = x.shape[0]
        counts = torch.zeros(
            B, self.num_codebooks, self.codebook_size, device=x.device
        )
        ones = torch.ones(1, 1, 1, device=x.device).expand(
            B, self.num_codebooks, self.codebook_size
        )
        counts.scatter_(src=ones, dim=2, index=indexes.unsqueeze(2))
        avg_counts = counts.mean(dim=0) + 1.0e-20
        index_entropy = -(avg_counts * avg_counts.log()).sum(dim=1).mean()

        probs = logits.exp().mean(dim=0) + 1.0e-20
        logits_entropy = -(probs * probs.log()).sum(dim=1).mean()
        ref_entropy = math.log(self.codebook_size)

        logits_entropy_loss = (ref_entropy - logits_entropy) / ref_entropy
        index_entropy_loss = (ref_entropy - index_entropy) / ref_entropy

        return (
            rel_reconstruction_loss,
            logprob_loss,
            logits_entropy_loss,
            index_entropy_loss,
        )

    def encode(
        self, x: Tensor, refine_indexes_iters: int = 5, as_bytes: bool = True
    ) -> Tensor:
        """
        Compute the quantized output, that can be used to reconstruct x.
        Args:
          x: the Tensor to quantize, of shape (*, dim)
          refine_indexes_iters: a number >= 0: the number of iterations to
            refine the indexes from their initial value.
          as_bytes: if True, the quantized output will be returned as a byte
            array, combining as many codes as possible into each byte if
            codebook_size <= 16.
        Returns: if as_bytes == False, a torch.LongTensor of shape
          (*, num_codebooks); if as_bytes == True, returns a Tensor of
          dtype=torch.uint8, of shape (*, num_codebooks/n), where pairs of
          codes are merged while codebook_size ** 2 <= 256 (e.g. n == 4 if
          codebook_size == 4, n == 2 if codebook_size <= 16, else n == 1).
        """
        x_reshaped = x.reshape(-1, self.dim)
        indexes = self._compute_indexes(x_reshaped, refine_indexes_iters)

        if as_bytes:
            codebook_size = self.codebook_size
            while codebook_size ** 2 <= 256:
                indexes = indexes[:, ::2] + codebook_size * indexes[:, 1::2]
                codebook_size = codebook_size ** 2
            assert codebook_size <= 256
            indexes = indexes.to(torch.uint8)

        return indexes.reshape(*x.shape[:-1], -1)
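    # Editor's illustrative note (not part of the original file): with
    # codebook_size == 16 and num_codebooks == 8, the packing loop above runs
    # exactly once (16 ** 2 == 256 <= 256), merging each pair of 4-bit codes
    # into one byte, so encode(x, as_bytes=True) returns a uint8 tensor of
    # shape (*, 4).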
|
||||||
|
|
||||||
|
def _logits(self, x: Tensor) -> Tensor:
|
||||||
|
return self.to_logits(x) * self.logits_scale
|
||||||
|
|
||||||
|
def _compute_indexes(
|
||||||
|
self, x: Tensor, refine_indexes_iters: int = 3
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Deterministically compute the indexes that encode the tensor x.
|
||||||
|
Args:
|
||||||
|
x: the Tensor to quantize, of shape (B, dim)
|
||||||
|
refine_indexes_iters: a number >= 0: the number of iterations to refine
|
||||||
|
the indexes from their initial value.
|
||||||
|
Returns: returns a torch.LongTensor of shape (B, num_codebooks),
|
||||||
|
with entries in {0..codebook_size-1}
|
||||||
|
"""
|
||||||
|
assert x.ndim == 2 and x.shape[1] == self.dim
|
||||||
|
B = x.shape[0]
|
||||||
|
x_reshaped = x.reshape(-1, self.dim)
|
||||||
|
B = x_reshaped.shape[0]
|
||||||
|
logits = self._logits(x_reshaped)
|
||||||
|
logits = logits.reshape(B, self.num_codebooks, self.codebook_size)
|
||||||
|
|
||||||
|
# indexes: (B, self.num_codebooks)
|
||||||
|
indexes = torch.argmax(logits, dim=-1)
|
||||||
|
for i in range(refine_indexes_iters):
|
||||||
|
indexes = self._refine_indexes(x_reshaped, indexes)
|
||||||
|
assert indexes.ndim == 2
|
||||||
|
return indexes.reshape(*x.shape[:-1], self.num_codebooks)
|
||||||
|
|
||||||
|
def _refine_indexes(self, x: Tensor, indexes: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Refine choices of indexes, minimizing sum-squared loss. Note, this is not guaranteed
|
||||||
|
not not increase the sum-squared loss, but works OK in practice.
|
||||||
|
Args:
|
||||||
|
x: A Tensor of shape (B, self.dim) to be approximated.
|
||||||
|
indexes: A Tensor of integer type, of shape (B, self.num_codebooks),
|
||||||
|
that contains elements in {0..self.codebook_size-1}
|
||||||
|
i: the iteration of refinement (may affect the groups we choose
|
||||||
|
to optimize)
|
||||||
|
Returns: A tensor of indexes of shape (B, self.num_codebooks) that
|
||||||
|
will hopefully reduce the error w.r.t. x, better or at least no worse
|
||||||
|
than `indexes`. This algorithm is not exact, but if the codebooks are
|
||||||
|
fairly orthogonal it should work fine. If they are not fairly orthogonal
|
||||||
|
it may not optimize well, but hopefully the codebooks will then learn
|
||||||
|
to be more orthogonal.
|
||||||
|
"""
|
||||||
|
B = indexes.shape[0]
|
||||||
|
# indexes_expanded has shape (B, self.num_codebooks, 1, self.dim)
|
||||||
|
indexes_expanded = (
|
||||||
|
indexes.unsqueeze(-1)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand(B, self.num_codebooks, 1, self.dim)
|
||||||
|
)
|
||||||
|
# all_centers: (1, num_codebooks, codebook_size, dim)
|
||||||
|
all_centers = self.centers.unsqueeze(0)
|
||||||
|
# centers_expanded has shape (B, self.num_codebooks, self.codebook_size, self.dim)
|
||||||
|
centers_expanded = all_centers.expand(
|
||||||
|
B, self.num_codebooks, self.codebook_size, self.dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# old_centers: (B, self.num_codebooks, 1, self.dim)
|
||||||
|
# centers with the "indexes" as passed-in.
|
||||||
|
old_centers = torch.gather(
|
||||||
|
centers_expanded, dim=2, index=indexes_expanded
|
||||||
|
)
|
||||||
|
# x_err is of shape (B, 1, 1, self.dim), it is the old value of (x_approx - x)
|
||||||
|
x_err = old_centers.sum(dim=1, keepdim=True) - x.unsqueeze(1).unsqueeze(
|
||||||
|
2
|
||||||
|
)
|
||||||
|
|
||||||
|
# The algorithm below is going to be iterative, where at each stage we
|
||||||
|
# have N K-way choices, with each choice corresponding to L codebook indexes.
|
||||||
|
# Initially N == num_codebooks, K == codebook_size, L == 1,
|
||||||
|
# and on the iterations of the algorithm we either:
|
||||||
|
# - terminate by finding the best choice, if currently N == 1
|
||||||
|
# - combine pairs of choices, so that N := N//2, K := K ** 2, L *= 2
|
||||||
|
# - reduce K to K_cutoff by sorting and taking the K_cutoff best possibilities
|
||||||
|
# for each choice. K_cutoff is a power of 2 that starts at 8 or 16
|
||||||
|
# and doubles every 2 iterations to keep the work per iteration
|
||||||
|
# fairly constant.
|
||||||
|
|
||||||
|
# At all points in the algorithm we maintain cur_sumsq and (conceptually)
|
||||||
|
# cur_deltas (however in some parts cur_deltas is not instantiated, see
|
||||||
|
# gather_deltas).
|
||||||
|
#
|
||||||
|
# cur_indexes: (B, N, K, L), initially (B, num_codebooks, codebook_size, 1),
|
||||||
|
# gives the codebook indexes corresponding to the k'th value of the n'th
|
||||||
|
# choice. Initially this is just an arange expression but from the 1st
|
||||||
|
# iter of the algorithm it changes to something nontrivial.
|
||||||
|
#
|
||||||
|
# cur_sumsq: (B, N, K), is the sum-squared error of x versus its predicted value
|
||||||
|
# from the codebooks, if we were to
|
||||||
|
# make the n'th choice with value k without making any of the other N-1 choices, i.e.
|
||||||
|
# if we were to leave the other choices at the value we had at input.
|
||||||
|
# Specifically, it is always supposed to equal the value of
|
||||||
|
# ((x_err + cur_deltas)**2).sum(dim=-1)
|
||||||
|
# .. but we keep it around separately because it enables an optimization.
|
||||||
|
#
|
||||||
|
# cur_deltas: (B, N, K, dim), is the change in x_err (with x_err =
|
||||||
|
# x_approx - x and x_approx being a sum of codebook indexes) if we were
|
||||||
|
# to make the n'th choice with value k without making any of the other
|
||||||
|
# N-1 choices.
|
||||||
|
# At the current point, i.e. at the start of the algorithm,
|
||||||
|
# cur_deltas[b][n][k] says "what would be the change in x_err if we
|
||||||
|
# were to replace the current choice of the n'th codebook entry-- i.e.
|
||||||
|
# the choice reflected in `indexes`-- with value k? [In general,
|
||||||
|
# cur_deltas[b][n][k] refers not directly to a codebook indexes, but
|
||||||
|
# to an indexes into `cur_indexes` which corresponds to the sequence/combination
|
||||||
|
# of codebook indexes that are stored in cur_indexes[b][n][k].
|
||||||
|
|
||||||
|
# cur_deltas represents the change in x_err from making each choice (while
|
||||||
|
# leaving all the other choices un-made by just keeping the passed-in/old
|
||||||
|
# indexes).
|
||||||
|
# cur_deltas: (B, N, K, dim),
|
||||||
|
N = self.num_codebooks
|
||||||
|
K = self.codebook_size
|
||||||
|
L = 1 # L is the number of codebooks covered by each choice.
|
||||||
|
# Conceptually we could do:
|
||||||
|
# cur_deltas = all_centers - old_centers # (B, N, K, dim)
|
||||||
|
# ... however actually we won't be instantiating cur_deltas at this stage of the
|
||||||
|
# algorithm.
|
||||||
|
dim = self.dim
|
||||||
|
|
||||||
|
# cur_indexes is the codebook indexes corresponding to 'cur_deltas'.
|
||||||
|
cur_indexes = (
|
||||||
|
torch.arange(K, device=x.device)
|
||||||
|
.reshape(1, 1, K, 1)
|
||||||
|
.expand(B, N, K, L)
|
||||||
|
)
|
||||||
|
|
||||||
|
if True:
|
||||||
|
# compute cur_sumsq using an efficient approach
|
||||||
|
x_err_sumsq = (x_err ** 2).sum(dim=-1) # (B, 1, 1)
|
||||||
|
|
||||||
|
x_remaining = (
|
||||||
|
x_err - old_centers
|
||||||
|
) # (B, num_codebooks, 1, dim): the x_err after subtracting
|
||||||
|
# each of the codebooks; if we add back to this any given
|
||||||
|
# codebook vector (from all_centers), we'll get the error
|
||||||
|
# if we were to
|
||||||
|
# choose that codebook entry instead of the one actually chosen.
|
||||||
|
|
||||||
|
x_remaining_sumsq = (x_remaining ** 2).sum(
|
||||||
|
dim=-1
|
||||||
|
) # (B, num_codebooks, 1)
|
||||||
|
# all_centers_sumsq is the sumsq of all the centers..
|
||||||
|
all_centers_sumsq = (all_centers ** 2).sum(
|
||||||
|
dim=-1
|
||||||
|
) # (1, num_codebooks, codebook_size)
|
||||||
|
|
||||||
|
cross_sum = torch.matmul(
|
||||||
|
all_centers, # (1, num_codebooks, codebook_size, dim)
|
||||||
|
x_remaining.permute(2, 1, 3, 0), # (1, num_codebooks, dim, B)
|
||||||
|
) # (1, num_codebooks, codebook_size, B)
|
||||||
|
cross_sum = cross_sum.squeeze(0).permute(
|
||||||
|
2, 0, 1
|
||||||
|
) # (B, num_codebooks, codebook_size)
|
||||||
|
# (B, num_codebooks, codebook_size); interpret as (B, N, K)
|
||||||
|
cur_sumsq = x_remaining_sumsq + all_centers_sumsq + 2 * cross_sum
|
||||||
|
assert cur_sumsq.shape == (B, N, K)
|
||||||
|
|
||||||
|
# gather_deltas (which will be re-defined below) is a lambda from
|
||||||
|
# `this_indexes`, a LongTensor of shape (B, N, new_K, 1) [which
|
||||||
|
# at the current iteration would equal (B, num_codebooks, new_K, 1)]
|
||||||
|
# with elements in
|
||||||
|
# {0..K-1} [i.e. 0..codebook_size-1], to the new "cur_deltas".
|
||||||
|
# It is provided as a workaround in
|
||||||
|
# case we did not physically instantiate cur_deltas on this iteration.
|
||||||
|
# In general cur_deltas is supposed to represent "change in encoded
|
||||||
|
# value" if we were to make a particular modified index choice, leaving
|
||||||
|
# all other choices as they were on entry.
|
||||||
|
# gather_deltas is supposed to be a lambda from this_indexes to the
|
||||||
|
# something equivalent to following expression (if cur_deltas had actually
|
||||||
|
# existed):
|
||||||
|
# torch.gather(input=cur_deltas, dim=2, index=this_indexes.expand(B, N, new_K, dim))
|
||||||
|
|
||||||
|
gather_deltas = lambda this_indexes: (
|
||||||
|
torch.gather(
|
||||||
|
input=all_centers.expand(B, N, K, dim),
|
||||||
|
dim=2,
|
||||||
|
index=this_indexes.expand(B, N, -1, dim),
|
||||||
|
)
|
||||||
|
- old_centers
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cur_deltas = all_centers - old_centers # (B, N, K, dim)
|
||||||
|
## cur_sumsq: (B, N, K), equivalent to: ((x_err + cur_deltas)**2).sum(dim=-1)
|
||||||
|
## We really want batched vector-vector product her, which torch does not
|
||||||
|
## explicitly support, so we use a matrix multiplication with 1x1 output.
|
||||||
|
modified_err = x_err + cur_deltas # (B, N, K, dim)
|
||||||
|
cur_sumsq = (
|
||||||
|
torch.matmul(
|
||||||
|
modified_err.unsqueeze(-2), modified_err.unsqueeze(-1)
|
||||||
|
)
|
||||||
|
.squeeze(-1)
|
||||||
|
.squeeze(-1)
|
||||||
|
)
|
||||||
|
gather_deltas = None
|
||||||
|
|
||||||
|
# x_err_sumsq: (B, 1, 1), is the sum-squared of x_err; we'll need it in the loop.
|
||||||
|
x_err_sumsq = (x_err ** 2).sum(dim=-1)
|
||||||
|
|
||||||
|
K_cutoff_base = 8 if self.codebook_size <= 16 else 16
|
||||||
|
|
||||||
|
def get_K_cutoff():
|
||||||
|
# Every time L increases by 4, we double K_cutoff. This keeps the
|
||||||
|
# work per iteration roughly constant, as it's linear in 1/L
|
||||||
|
# and in K_cutoff**2.
|
||||||
|
K_cutoff, l = K_cutoff_base, L
|
||||||
|
while l >= 4:
|
||||||
|
l /= 4
|
||||||
|
K_cutoff *= 2
|
||||||
|
return min(K_cutoff, 128)
|
||||||

        while True:
            K_cutoff = get_K_cutoff()

            if N == 1 and K == 1:
                return cur_indexes.squeeze(2).squeeze(
                    1
                )  # (B, L) == (B, num_codebooks)
            elif K > K_cutoff or N == 1:
                # Sort the options for each choice, and reduce K.
                # this_indexes: (B, N, K); elements in {0..K-1}.  These
                # are sorted from best (lowest sumsq) to worst.
                _, this_indexes = torch.sort(cur_sumsq, dim=2)

                new_K = 1 if N == 1 else K_cutoff
                this_indexes = this_indexes[:, :, :new_K]
                cur_sumsq = torch.gather(
                    input=cur_sumsq, dim=2, index=this_indexes
                )

                this_indexes = this_indexes.unsqueeze(-1)

                # cur_indexes becomes (B, N, new_K, L), with only the chosen
                # indexes kept.
                cur_indexes = torch.gather(
                    input=cur_indexes,
                    dim=2,
                    index=this_indexes.expand(B, N, new_K, L),
                )

                if gather_deltas is None:
                    # Also sort cur_deltas in the same way.
                    cur_deltas = torch.gather(
                        input=cur_deltas,
                        dim=2,
                        index=this_indexes.expand(B, N, new_K, dim),
                    )
                else:
                    # gather_deltas should be a lambda from:
                    #   this_indexes: a LongTensor of shape (B, N, new_K, 1)
                    #   containing elements in {0..K-1}
                    # to the new "deltas", which should be of shape
                    # (B, N, new_K, dim), representing the difference from
                    # the baseline "x_offset" if we choose this index for
                    # this codebook or range of codebooks, leaving other
                    # choices as they were at entry to this function.
                    cur_deltas = gather_deltas(this_indexes)
                    gather_deltas = None
                K = new_K
            else:
                # Combine pairs of choices.  We know that N > 1.
                even_deltas = cur_deltas[:, 0::2, :, :]
                odd_deltas = cur_deltas[:, 1::2, :, :]
                even_indexes = cur_indexes[:, 0::2, :, :]
                odd_indexes = cur_indexes[:, 1::2, :, :]
                even_sumsq = cur_sumsq[:, 0::2, :]
                odd_sumsq = cur_sumsq[:, 1::2, :]

                new_N = N // 2
                new_K = K ** 2
                new_L = L * 2

                even_indexes = (
                    even_indexes.unsqueeze(3)
                    .expand(B, new_N, K, K, L)
                    .reshape(B, new_N, new_K, L)
                )
                odd_indexes = (
                    odd_indexes.unsqueeze(2)
                    .expand(B, new_N, K, K, L)
                    .reshape(B, new_N, new_K, L)
                )
                cur_indexes = torch.cat((even_indexes, odd_indexes), dim=3)

                even_sumsq = even_sumsq.unsqueeze(3)  # (B, new_N, K, 1)
                odd_sumsq = odd_sumsq.unsqueeze(2)  # (B, new_N, 1, K)
                # even_sumsq + odd_sumsq alone would be only a partial
                # version of the new cur_sumsq; we have to add two correction
                # terms.  The new cur_sumsq that we want can be expressed as:
                #   ((a + b + c)**2).sum(dim=-1),
                # where a == x_err, b == even_deltas, c == odd_deltas.
                # Ignoring the summation, we can write this as:
                #   a^2 + b^2 + c^2 + 2ab + 2ac + 2bc.
                # We can rearrange this as:
                #   (a^2 + b^2 + 2ab) + (a^2 + c^2 + 2ac) - a^2 + 2bc,
                # which is the same as
                #   even_sumsq + odd_sumsq - x_err_sumsq + 2bc,
                # where 2bc is a certain matrix product of even_deltas and
                # odd_deltas.
                cur_sumsq = (
                    (even_sumsq + odd_sumsq).reshape(B, new_N, new_K)
                    - x_err_sumsq
                    + 2
                    * torch.matmul(
                        even_deltas, odd_deltas.transpose(2, 3)
                    ).reshape(B, new_N, new_K)
                )
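                # Quick numeric check of the rearrangement, with scalars
                # a=1, b=2, c=3: (1+2+3)^2 = 36, and
                # (1+2)^2 + (1+3)^2 - 1^2 + 2*2*3 = 9 + 16 - 1 + 12 = 36.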

                saved_K = K
                gather_deltas = lambda this_indexes: (
                    torch.gather(
                        input=even_deltas,
                        dim=2,
                        index=(this_indexes // saved_K).expand(
                            *this_indexes.shape[:-1], dim
                        ),
                    )
                    + torch.gather(
                        input=odd_deltas,
                        dim=2,
                        index=(this_indexes % saved_K).expand(
                            *this_indexes.shape[:-1], dim
                        ),
                    )
                )

                # Unset cur_deltas; it is now invalid, but we'll reconstruct
                # it using gather_deltas.
                cur_deltas = None

                N, K, L = new_N, new_K, new_L
                assert cur_indexes.shape == (B, N, K, L)
                assert cur_sumsq.shape == (B, N, K)

    def _maybe_separate_indexes(self, indexes: Tensor) -> Tensor:
        """
        This reverses the process done in encode() if as_bytes==True, which
        combines multiple codebook entries into a single byte if
        self.codebook_size is small enough.

        Args:
          indexes: an integer tensor of shape (B, n) where n divides
             self.num_codebooks
        Returns: a tensor of the same type as `indexes`, of shape
             (B, self.num_codebooks)
        """
        B = indexes.shape[0]
        if indexes.shape[-1] != self.num_codebooks:
            n = indexes.shape[-1]
            num_repeats = self.num_codebooks // n
            assert (
                num_repeats in [2, 4, 8, 16]
                and self.num_codebooks == n * num_repeats
            )
            indexes = indexes.unsqueeze(2).expand(B, n, num_repeats)
            size = self.codebook_size
            indexes = (
                indexes
                // (size ** torch.arange(num_repeats, device=indexes.device))
            ) % size
            indexes = indexes.reshape(B, self.num_codebooks)
        assert indexes.shape == (B, self.num_codebooks)
        return indexes
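
    # A small worked example of the unpacking arithmetic above (a sketch;
    # encode() itself is not shown in this excerpt): with codebook_size=16
    # and num_repeats=2, two 4-bit indexes (i0, i1) would have been packed
    # into the single byte i0 + 16*i1, and
    #   (packed // 16 ** torch.arange(2)) % 16
    # recovers them.  E.g. packed=0x4A=74 gives 74 // 1 % 16 == 10 and
    # 74 // 16 % 16 == 4, i.e. (i0, i1) == (10, 4).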


class QuantizerTrainer(object):
    def __init__(
        self,
        dim: int,
        bytes_per_frame: int,
        device: torch.device,
        phase_one_iters: int = 10000,
        phase_two_iters: int = 10000,
        lr: float = 0.005,
    ):
        """
        Args:
            dim: The feature dimension we are trying to quantize, e.g. 512
            bytes_per_frame: The number of bytes to use to quantize each
                vector of `dim` values.
            device: The device to use for training
            phase_one_iters: The number of iterations to use for the first
                phase of training (with codebook_size=16); after this we
                will convert to codebook_size=256.  These parameters were
                tuned with a batch size of 600: if your batch size (in
                frames) is smaller than this, you may benefit from a larger
                phase_one_iters and a smaller learning rate.
                [Also, note: phase_one_iters should be larger for larger
                dims; for dim=256 and batch_size=600, 10k was enough, but
                for dim=512 and batch_size=600, 20k was better.]
            phase_two_iters: The number of iterations to use for the second
                phase of training (with codebook_size=256)
            lr: The initial learning rate.

        This object trains a Quantizer.  You can use it as follows:

          trainer = QuantizerTrainer(...)
          while not trainer.done():
              # let x be some tensor of shape (*, dim), that you will train
              # on (should not be the same on each minibatch)
              trainer.step(x)
          quantizer = trainer.get_quantizer()
        """
        super(QuantizerTrainer, self).__init__()
        assert bytes_per_frame in [1, 2, 4, 8, 16, 32]

        # We'll initially train with codebook_size=16 and
        # num_codebooks=bytes_per_frame * 2; then, after `phase_one_iters`
        # of training, we will multiply pairs of codebooks so that
        # codebook_size=256 and num_codebooks=bytes_per_frame.

        self.phase_one_iters = phase_one_iters
        self.phase_two_iters = phase_two_iters
        self.cur_iter = 0
        self.lr = lr
        self.two_iter_prob = 0.5

        self.quantizer = Quantizer(
            dim=dim, codebook_size=16, num_codebooks=bytes_per_frame * 2
        ).to(device)
        self.start_time = time.time()
        self._init_optimizer()
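
        # For example, with bytes_per_frame=4, phase one trains 8 codebooks
        # of size 16 (8 codebooks * 4 bits each = 4 bytes/frame), and phase
        # two merges them pairwise into 4 codebooks of size 256 (4 codebooks
        # * 8 bits each = 4 bytes/frame), so the byte budget is unchanged.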

    def done(self) -> bool:
        ans = self.cur_iter > self.phase_one_iters + self.phase_two_iters
        if ans:
            elapsed_time = time.time() - self.start_time
            logging.info(
                f"Elapsed time, training model of dim={self.quantizer.dim}, "
                f"num_codebooks={self.quantizer.num_codebooks}, "
                f"codebook_size={self.quantizer.codebook_size}, is: "
                f"{elapsed_time:.2f} seconds."
            )
        return ans

    def step(self, x: torch.Tensor) -> None:
        """
        Does one step of training.  You must call this at least
        phase_one_iters + phase_two_iters times, i.e. until done() returns
        True.
        Args:
            x: a Tensor of shape (*, dim) containing the frames of data we
               are trying to accurately encode.
        """
        x = x.reshape(-1, self.quantizer.dim)

        num_iters = 2 if random.random() < self.two_iter_prob else 1
        (
            reconstruction_loss,
            logprob_loss,
            logits_entropy_loss,
            index_entropy_loss,
        ) = self.quantizer.compute_loss(x, num_iters)

        if self.cur_iter % 200 == 0:
            det_losses = [
                float("%.3f" % self.quantizer.compute_loss(x, j)[0].item())
                for j in range(6)
            ]
            phase = 1 if self.cur_iter <= self.phase_one_iters else 2
            i = (
                self.cur_iter - self.phase_one_iters
                if phase > 1
                else self.cur_iter
            )
            # Caution: python's logging level is logging.ERROR by default.
            # To make the following be printed, you may have to do:
            #   import logging
            #   logging.getLogger().setLevel(logging.INFO)
            # before using this code.
            logging.info(
                f"phase={phase}/2, iter={i}, "
                f"dim,nc,csz={self.quantizer.dim},{self.quantizer.num_codebooks},{self.quantizer.codebook_size}, "
                f"loss_per_iter={det_losses}, "
                f"logprob_loss={logprob_loss.item():.3f}, "
                f"logits_entropy_loss={logits_entropy_loss.item():.3f}, "
                f"index_entropy_loss={index_entropy_loss.item():.3f}"
            )

        if self.cur_iter % 2000 == 0 and self.cur_iter > 0:
            correlations = self.quantizer.compute_codebook_correlations()
            logging.info(f"correlations = {correlations}")

        # We did not find it necessary to use entropy_scale -- the
        # logits_entropy_loss and index_entropy_loss are less than 0.01 even
        # with entropy_scale == 0 -- but we are putting a nonzero value on
        # entropy_scale just out of an abundance of caution, in case an
        # unusual data distribution might cause problems in the future.
        entropy_scale = 0.01

        # About the losses:
        # - reconstruction_loss >= 0; it equals 0 when reconstruction is
        #   exact.  This is the main loss function, used to train
        #   quantizer.centers.
        # - logprob_loss trains only quantizer.to_logits, which predicts the
        #   indexes after refinement, so we can initialize them well; it
        #   does not affect the cluster centers.
        # - logits_entropy_loss is used for training only with the small
        #   weight entropy_scale set above; it would affect only to_logits.
        #   The intention was that this might solve a problem with cluster
        #   centers having very uneven probabilities of being chosen (it
        #   would do this by biasing the initial choice, relying on the
        #   inexactness of the search).  In our experiments,
        #   logits_entropy_loss and index_entropy_loss both end up less than
        #   0.05, so this does not seem to be a problem in practice, but it
        #   might be a problem if, say, the inputs had a very tiny scale, so
        #   we are keeping the code around.
        # - index_entropy_loss is not differentiable; we have added it only
        #   for diagnostic purposes.  It reflects the entropy of the
        #   distribution over classes, after refining the cluster indexes.
        #   It was computed just in case regularizing logits_entropy_loss
        #   was not enough to affect the final distribution over cluster
        #   centers, so we could diagnose the problem; but we found no
        #   problem in practice.

        tot_loss = (
            reconstruction_loss
            + logprob_loss
            + logits_entropy_loss * entropy_scale
        )
        # We want to maximize frame_entropy if it is less than
        # frame_entropy_cutoff.
        # tot_loss -= torch.minimum(self.frame_entropy_cutoff,
        #                           frame_entropy)

        tot_loss.backward()
        self.optim.step()
        self.optim.zero_grad()
        self.scheduler.step()

        if self.cur_iter == self.phase_one_iters:
            self._begin_second_phase()
        self.cur_iter += 1

    def _init_optimizer(self):
        self.optim = torch.optim.Adam(
            self.quantizer.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-9,
            weight_decay=1.0e-06,
        )
        # Halve the learning rate roughly every quarter of the current
        # phase.  (// 4 keeps step_size an integer, as StepLR expects.)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optim,
            step_size=(
                self.phase_one_iters
                if self.cur_iter == 0
                else self.phase_two_iters
            )
            // 4,
            gamma=0.5,
        )
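
    # Schedule sketch, derived from the defaults above
    # (phase_one_iters=10000, lr=0.005): within phase one the LR halves at
    # iterations 2500, 5000 and 7500
    # (0.005 -> 0.0025 -> 0.00125 -> 0.000625); phase two then repeats the
    # pattern starting from lr=0.0025, since _begin_second_phase() below
    # halves self.lr and re-creates the optimizer and scheduler.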

    def _begin_second_phase(self):
        """
        This is to be called exactly once, when self.cur_iter reaches
        self.phase_one_iters.
        """
        self.quantizer = self.quantizer.get_product_quantizer()
        self.lr *= 0.5
        self._init_optimizer()

    def get_quantizer(self) -> Quantizer:
        assert self.cur_iter >= self.phase_one_iters + self.phase_two_iters
        return self.quantizer
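

# A minimal end-to-end sketch of using QuantizerTrainer (illustrative only:
# the dim, bytes_per_frame, device and random minibatches below are
# assumptions, not taken from the original file):
#
#   trainer = QuantizerTrainer(dim=512, bytes_per_frame=4,
#                              device=torch.device("cpu"))
#   while not trainer.done():
#       x = torch.randn(600, 512)  # one minibatch of frames
#       trainer.step(x)
#   quantizer = trainer.get_quantizer()
#   codes = quantizer.encode(x)  # encode() as referenced by
#                                # _maybe_separate_indexes() above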


def read_hdf5_data(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reads the hdf5 archive in the file with name 'filename' into a single
    numpy array of size (tot_frames, dim), shuffles the frames, and returns
    the result as two torch tensors (train, valid).
    The dtype will be the same as it was in the archive (e.g. float16).

    Args:
      filename: the name of your hdf5 archive.  It should have been created
        using code similar to the code in test_write_hdf5.py, e.g. something
        like:
          hf = h5py.File(filename, 'w')
          for i in range(...):
              # get x as some numpy array of type np.float16, and shape
              # (*, dim); the names do not actually matter, except that they
              # should be distinct.
              hf.create_dataset(f'dataset_{i}', data=x)
    Returns (train, valid), where:
      train: a torch.Tensor of shape (tot_train_frames, dim), on CPU, with
        dtype=torch.float16, with shuffled rows.
      valid: a torch.Tensor of shape (tot_valid_frames, dim), on CPU, with
        dtype=torch.float16, with shuffled rows (these are distinct frames
        from those in `train`, but may derive from different rows of the
        same original tensors.)

    Caution: you should set the logger to INFO level, with:
      logging.getLogger().setLevel(logging.INFO)
    if you want to see the logging output of this function.
    """
    logging.info(f"Opening file {filename}")
    hf = h5py.File(filename, "r")
    tot_frames = 0
    dim = -1

    def get_num_frames(shape):
        # Returns the product of shape[0], shape[1], ..., shape[-2].
        num_frames = 1
        for i in shape[:-1]:
            num_frames *= i
        return num_frames

    for key in hf.keys():
        dset = hf[key]
        shape = list(dset.shape)
        if dim == -1:
            dim = shape[-1]
        else:
            assert (
                dim == shape[-1]
            ), "Datasets must have consistent dimension (last element of shape)"
        tot_frames += get_num_frames(shape)
    logging.info(f"read_data: tot_frames = {tot_frames}")

    ans = np.empty((tot_frames, dim), dtype=np.float16)
    cur_pos = 0
    for key in hf.keys():
        array = hf[key][:]  # [:] reads the dataset in as a NumPy array.
        array = np.ascontiguousarray(array).reshape(-1, dim)
        num_frames = array.shape[0]
        ans[cur_pos : cur_pos + num_frames, :] = array  # noqa E203
        cur_pos += num_frames
    assert cur_pos == tot_frames

    # Shuffle the rows of ans.
    np.random.shuffle(ans)
    ans_torch = torch.from_numpy(ans)

    valid_proportion = 0.05
    # int() so that the tensor slicing below gets an integer index.
    valid_frames = int(valid_proportion * tot_frames)
    if valid_frames > 10000:
        valid_frames = 10000
    train_frames = tot_frames - valid_frames
    logging.info(
        f"read_data: train_frames={train_frames}, valid_frames={valid_frames}"
    )

    # Return (train, valid).
    return ans_torch[valid_frames:tot_frames], ans_torch[:valid_frames]
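

# A minimal sketch (mirroring the docstring of read_hdf5_data above) of
# writing an archive that read_hdf5_data() can consume; the file name and
# shapes here are illustrative assumptions:
#
#   import h5py
#   import numpy as np
#   hf = h5py.File("feats.h5", "w")
#   for i in range(20):
#       x = np.random.randn(300, 512).astype(np.float16)
#       hf.create_dataset(f"dataset_{i}", data=x)
#   hf.close()
#   train, valid = read_hdf5_data("feats.h5")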