copy quantization files from dan's repo

Guo Liyong 2021-12-23 18:33:39 +08:00
parent 3570cb738a
commit 8985440ce1
6 changed files with 1632 additions and 2 deletions

View File

@@ -1,2 +1,5 @@
egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py
egs/librispeech/ASR/conformer_ctc/test_subsampling.py
egs/librispeech/ASR/conformer_ctc/checkpoint.py
egs/librispeech/ASR/conformer_ctc/ckpnt_prediction.py
egs/librispeech/ASR/conformer_ctc/powerful_prediction.py
egs/librispeech/ASR/conformer_ctc/prediction.py
egs/librispeech/ASR/conformer_ctc/quantization.py

View File: egs/librispeech/ASR/conformer_ctc/checkpoint.py

@@ -0,0 +1,78 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple, Callable
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, function: Callable, *args):
# `function` must return either a Tensor or a tuple of Tensors
ctx.function = function
ctx.args = [x.detach() if isinstance(x, Tensor) else x for x in args]
for i in range(len(ctx.args)):
if isinstance(args[i], Tensor) and args[i].requires_grad:
ctx.args[i].requires_grad = True
with torch.no_grad():
ans = function(*args)
return ans
@staticmethod
def backward(ctx, *ans_grads):
        if not any([a is not None for a in ans_grads]):
            # One grad for `function` (None) plus one per arg, matching forward's inputs.
            return tuple([None] * (len(ctx.args) + 1))
with torch.enable_grad():
ans = ctx.function(*ctx.args)
if isinstance(ans, Tensor):
assert len(ans_grads) == 1
loss = (ans * ans_grads[0]).sum()
else:
assert len(ans_grads) == len(ans)
loss = torch.stack(
[
(a * g).sum()
for a, g in zip(ans, ans_grads)
if g is not None
]
).sum()
loss.backward()
return tuple(
[None]
+ [a.grad if isinstance(a, Tensor) else None for a in ctx.args]
)
def checkpoint(function, *args):
return CheckpointFunction.apply(function, *args)
def _test1():
x = torch.Tensor([0])
y = torch.Tensor([1])
y.requires_grad = True
l = lambda x, y, trash: torch.stack((x, y))
ans = checkpoint(l, x, y, None)
# ans = l(x, y, None)
print("ans = ", ans)
ans.sum().backward()
print("y grad = ", y.grad)
def _test2():
x = torch.Tensor([0])
y = torch.Tensor([1])
x.requires_grad = True
l = lambda x, y, trash: torch.stack((x, y))
ans = checkpoint(l, x, y, None)
ans = checkpoint(torch.sum, ans)
# ans = l(x, y, None)
print("ans = ", ans)
ans.backward()
print("x grad = ", x.grad)
if __name__ == "__main__":
_test1()
_test2()
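
A further sketch (not part of the commit; the module and sizes are illustrative): the wrapper also composes with real modules, and gradients through the checkpointed call should match a direct call, since backward() re-runs the function under enable_grad:

m = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
x = torch.randn(3, 4, requires_grad=True)
checkpoint(m, x).sum().backward()  # forward runs under no_grad, recomputed in backward
g1 = x.grad.clone()
x.grad = None
m(x).sum().backward()  # ordinary, non-checkpointed path
assert torch.allclose(g1, x.grad)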

View File: egs/librispeech/ASR/conformer_ctc/ckpnt_prediction.py

@@ -0,0 +1,211 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple, Optional
from checkpoint import (
checkpoint,
) # from current directory.. could not get relative import to work..
# functional version of joint codebook loss, added so that we can more easily implement
# checkpointing to save memory.
def joint_codebook_loss(
predictor: Tensor,
codebook_indexes: Tensor,
linear1_weight: Tensor,
linear1_bias: Optional[Tensor],
codebook_embedding_weight: Tensor,
linear2_weight: Tensor,
linear2_bias: Tensor,
ignore_index: int,
reduction: str,
) -> Tensor:
"""
Args:
predictor: predictor tensor of shape (*, predictor_channels)
codebook_indexes: codebook indexes of shape (*, num_codebooks)
linear1_weight: weight of shape (hidden_channels, predictor_channels)
linear1_bias: optional bias of shape (hidden_channels,)
codebook_embedding_weight: weight of shape ((num_codebooks - 1) * codebook_size,
hidden_channels)
linear2_weight: weight of shape (num_codebooks, codebook_size,
hidden_channels)
linear2_bias: bias of shape (num_codebooks, codebook_size)
ignore_index: index to ignore in cross entropy loss, e.g. -100
reduction: reduction in cross entropy loss, e.g. 'sum'
"""
num_codebooks = codebook_indexes.shape[-1]
predictor_channels = predictor.shape[-1]
hidden_channels = linear1_weight.shape[0]
codebook_size = codebook_embedding_weight.shape[0] // (num_codebooks - 1)
codebook_indexes = codebook_indexes.to(torch.int64)
assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
predictor = predictor.reshape(
-1, predictor.shape[-1]
) # (N, predictor_channels)
codebook_indexes = codebook_indexes.reshape(-1, codebook_indexes.shape[-1])
first_indexes = codebook_indexes[
:, :-1
] # all but last codebook indexes; (N, num_codebooks-1)
# do clamp(min=0) to avoid errors on padding (-100).. these frames will
# later be ignored in the loss, so the value can be treated as a don't-care.
first_indexes = first_indexes.clamp(min=0) + torch.arange(
0,
(num_codebooks - 1) * codebook_size,
step=codebook_size,
device=first_indexes.device,
) # (N, num_codebooks-1)
first_embeddings = torch.nn.functional.embedding(
first_indexes, codebook_embedding_weight
) * (
hidden_channels ** 0.5
) # (N, num_codebooks-1, hidden_channels)
hidden_predictor = torch.nn.functional.linear(
predictor, linear1_weight, linear1_bias
)
all_embeddings = torch.cat(
(hidden_predictor.unsqueeze(1), first_embeddings), dim=1
) # (N, num_codebooks, hidden_channels)
# after cumsum, all positions will contain a contribution from 'hidden_predictor'; and
# will also contain contributions from all *previous* codebooks. Here, "position" means
# a position in {0..num_codebooks-1}
all_embeddings = torch.cumsum(
all_embeddings, dim=1
) # (N, num_codebooks, hidden_channels)
all_embeddings = torch.nn.functional.relu(all_embeddings)
logprobs = torch.matmul(
all_embeddings.transpose(0, 1), # (num_codebooks, N, hidden_channels)
linear2_weight.transpose(
1, 2
), # (num_codebooks, hidden_channels, codebook_size)
).transpose(
0, 1
) # (N, num_codebooks, codebook_size)
logprobs += linear2_bias
logprobs = logprobs.log_softmax(dim=2) # (N, num_codebooks, codebook_size)
return torch.nn.functional.cross_entropy(
logprobs.reshape(-1, codebook_size),
codebook_indexes.reshape(-1),
ignore_index=ignore_index,
reduction=reduction,
)
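# Illustrative sketch (not in the original file): a quick shape check of the
# functional form with small made-up sizes; the result is a scalar when
# reduction == 'sum'.
def _test_joint_codebook_loss():
    N, C, K, P, H = 3, 4, 5, 6, 8  # frames, codebooks, codebook size, channels
    loss = joint_codebook_loss(
        predictor=torch.randn(N, P),
        codebook_indexes=torch.randint(0, K, (N, C)),
        linear1_weight=torch.randn(H, P),
        linear1_bias=torch.zeros(H),
        codebook_embedding_weight=torch.randn((C - 1) * K, H),
        linear2_weight=torch.randn(C, K, H),
        linear2_bias=torch.zeros(C, K),
        ignore_index=-100,
        reduction="sum",
    )
    assert loss.ndim == 0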
class JointCodebookLoss(nn.Module):
"""
This module predicts a group of codebook indexes from a vector. The idea is that
you have a number of codebooks (probably jointly trained), from class Quantizer,
and you want to predict the probabilities of the codebook entries based on some
predictor that you are training.
The simplest thing would be to project the vector using nn.Linear, then
reshape and use logsoftmax to normalize the probabilities within each group,
then compute the likelihood. However, this has a constraint that all the
codebooks are predicted independently of each other. This module allows you
to predict them jointly, by regressing each codebook on all previous codebooks.
This is done with a nonlinearity in which the previous codebook entries are combined
with the input predictor vector, so that the regression is not purely
linear.
    Args:
      predictor_channels: the number of features that we use to predict the codebook
           indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting;
           will likely be the same as the bytes_per_frame given to the
           QuantizerTrainer that you used to train the Quantizer you
           are predicting.
      hidden_channels: a hidden dimension in the model; should be more than
           codebook_size, but may be less or more than predictor_channels.
      codebook_size: number of entries per codebook (often 256)
      reduction: the reduction to use in the cross-entropy loss, e.g. 'sum'
      ignore_index: value in codebook_indexes to treat as padding, e.g. -100
      checkpoint: if true, reduce backprop memory at the expense of doing
           the computation twice.
    """
def __init__(
self,
predictor_channels: int,
num_codebooks: int,
hidden_channels: int = 512,
codebook_size: int = 256,
reduction: str = "sum",
ignore_index: int = -100,
checkpoint: bool = True,
):
super(JointCodebookLoss, self).__init__()
assert num_codebooks > 1 # we may later handle this specially.
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.hidden_channels = hidden_channels
self.ignore_index = ignore_index
self.reduction = reduction
self.checkpoint = checkpoint
self.linear1 = nn.Linear(predictor_channels, hidden_channels)
        # codebook_embedding is used to predict each codebook from previous
        # codebooks, so it's a joint, not independent, model.  we'll multiply
        # this by hidden_channels ** 0.5 when we use it; keeping the stored
        # magnitude small allows it to train fast enough (relatively speaking).
self.codebook_embedding = nn.Embedding(
(num_codebooks - 1) * codebook_size,
hidden_channels,
_weight=torch.randn(
(num_codebooks - 1) * codebook_size, hidden_channels
)
* (hidden_channels ** -0.5),
)
self.nonlin = nn.ReLU(inplace=True)
self.linear2_weight = nn.Parameter(
torch.randn(num_codebooks, codebook_size, hidden_channels)
* (hidden_channels ** -0.5)
)
self.linear2_bias = nn.Parameter(
torch.zeros(num_codebooks, codebook_size)
)
def forward(
self, predictor: Tensor, codebook_indexes: Tensor
    ) -> Tensor:
"""
Forward function.
Args:
predictor: a Tensor of some real type, with shape (*, predictor_channels).
codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
where the '*' should be the same as for `predictor`. It will be
converted to type torch.int64. Should contain indexes of codebook
entries, in {0..codebook_size-1},
or negative values which will be interpreted as "no codebook index here"
(e.g. due to padding); we assume that each frame will either have
all-negative or all-nonnegative indexes, meaning that (codebook_indexes >= 0)
should not vary as you change the last index into it.
Returns:
           cross_entropy_loss: the total negated log-probability, assuming
reduction == 'sum'.
"""
args = (
predictor,
codebook_indexes,
self.linear1.weight,
self.linear1.bias,
self.codebook_embedding.weight,
self.linear2_weight,
self.linear2_bias,
self.ignore_index,
self.reduction,
)
if self.checkpoint:
return checkpoint(joint_codebook_loss, *args)
else:
return joint_codebook_loss(*args)
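
A minimal usage sketch (sizes are illustrative): the leading dimensions of `predictor` and `codebook_indexes` must match, and padded frames are marked with ignore_index:

loss_fn = JointCodebookLoss(predictor_channels=2048, num_codebooks=4)
predictor = torch.randn(10, 20, 2048, requires_grad=True)  # (batch, time, channels)
indexes = torch.randint(0, 256, (10, 20, 4))
indexes[0, :3, :] = -100  # padded frames contribute nothing to the loss
loss = loss_fn(predictor, indexes)
loss.backward()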

View File: egs/librispeech/ASR/conformer_ctc/powerful_prediction.py

@@ -0,0 +1,228 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple, Optional
from checkpoint import (
checkpoint,
) # from current directory.. could not get relative import to work..
# functional version of joint codebook loss, added so that we can more easily implement
# checkpointing to save memory.
def joint_codebook_loss(
predictor: Tensor,
codebook_indexes: Tensor,
linear1_weight: Tensor,
linear1_bias: Optional[Tensor],
codebook_embedding_weight: Tensor,
linear2_weight: Tensor,
linear2b_weight: Tensor,
linear2_bias: Tensor,
ignore_index: int,
reduction: str,
) -> Tensor:
"""
Args:
predictor: predictor tensor of shape (*, predictor_channels)
codebook_indexes: codebook indexes of shape (*, num_codebooks)
linear1_weight: weight of shape (hidden_channels, predictor_channels)
linear1_bias: optional bias of shape (hidden_channels,)
codebook_embedding_weight: weight of shape ((num_codebooks - 1) * codebook_size,
hidden_channels)
linear2_weight: weight of shape (num_codebooks, codebook_size,
hidden_channels)
linear2b_weight: weight of shape (num_codebooks, codebook_size,
predictor_dim)
linear2_bias: bias of shape (num_codebooks, codebook_size)
ignore_index: index to ignore in cross entropy loss, e.g. -100
reduction: reduction in cross entropy loss, e.g. 'sum'
"""
num_codebooks = codebook_indexes.shape[-1]
predictor_channels = predictor.shape[-1]
hidden_channels = linear1_weight.shape[0]
codebook_size = codebook_embedding_weight.shape[0] // (num_codebooks - 1)
codebook_indexes = codebook_indexes.to(torch.int64)
assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
predictor = predictor.reshape(
-1, predictor.shape[-1]
) # (N, predictor_channels)
codebook_indexes = codebook_indexes.reshape(-1, codebook_indexes.shape[-1])
first_indexes = codebook_indexes[
:, :-1
] # all but last codebook indexes; (N, num_codebooks-1)
# do clamp(min=0) to avoid errors on padding (-100).. these frames will
# later be ignored in the loss, so the value can be treated as a don't-care.
first_indexes = first_indexes.clamp(min=0) + torch.arange(
0,
(num_codebooks - 1) * codebook_size,
step=codebook_size,
device=first_indexes.device,
) # (N, num_codebooks-1)
first_embeddings_scale = 0.5 * ((hidden_channels / num_codebooks) ** 0.5)
first_embeddings = (
torch.nn.functional.embedding(first_indexes, codebook_embedding_weight)
* first_embeddings_scale
) # (N, num_codebooks-1, hidden_channels)
hidden_predictor = torch.nn.functional.linear(
predictor, linear1_weight, linear1_bias
)
all_embeddings = torch.cat(
(hidden_predictor.unsqueeze(1), first_embeddings), dim=1
) # (N, num_codebooks, hidden_channels)
# after cumsum, all positions will contain a contribution from 'hidden_predictor'; and
# will also contain contributions from all *previous* codebooks. Here, "position" means
# a position in {0..num_codebooks-1}
all_embeddings = torch.cumsum(
all_embeddings, dim=1
) # (N, num_codebooks, hidden_channels)
all_embeddings = torch.nn.functional.relu(all_embeddings)
logprobs = torch.matmul(
all_embeddings.transpose(0, 1), # (num_codebooks, N, hidden_channels)
linear2_weight.transpose(
1, 2
), # (num_codebooks, hidden_channels, codebook_size)
).transpose(
0, 1
) # (N, num_codebooks, codebook_size)
logprobs += torch.matmul(
predictor, # (N, predictor_channels)
linear2b_weight.transpose(
1, 2
), # (num_codebooks, predictor_channels, codebook_size)
).transpose(
0, 1
) # (N, num_codebooks, codebook_size)
logprobs += linear2_bias
logprobs = logprobs.log_softmax(dim=2) # (N, num_codebooks, codebook_size)
return torch.nn.functional.cross_entropy(
logprobs.reshape(-1, codebook_size),
codebook_indexes.reshape(-1),
ignore_index=ignore_index,
reduction=reduction,
)
class Powerful_JointCodebookLoss(nn.Module):
"""
This module predicts a group of codebook indexes from a vector. The idea is that
you have a number of codebooks (probably jointly trained), from class Quantizer,
and you want to predict the probabilities of the codebook entries based on some
predictor that you are training.
The simplest thing would be to project the vector using nn.Linear, then
reshape and use logsoftmax to normalize the probabilities within each group,
then compute the likelihood. However, this has a constraint that all the
codebooks are predicted independently of each other. This module allows you
to predict them jointly, by regressing each codebook on all previous codebooks.
This is done with a nonlinearity in which the previous codebook entries are combined
with the input predictor vector, so that the regression is not purely
linear.
    Args:
      predictor_channels: the number of features that we use to predict the codebook
           indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting;
           will likely be the same as the bytes_per_frame given to the
           QuantizerTrainer that you used to train the Quantizer you
           are predicting.
      hidden_channels: a hidden dimension in the model; should be more than
           codebook_size, but may be less or more than predictor_channels.
      codebook_size: number of entries per codebook (often 256)
      reduction: the reduction to use in the cross-entropy loss, e.g. 'sum'
      ignore_index: value in codebook_indexes to treat as padding, e.g. -100
      checkpoint: if true, reduce backprop memory at the expense of doing
           the computation twice.
    """
def __init__(
self,
predictor_channels: int,
num_codebooks: int,
hidden_channels: int = 512,
codebook_size: int = 256,
reduction: str = "sum",
ignore_index: int = -100,
checkpoint: bool = True,
):
super(Powerful_JointCodebookLoss, self).__init__()
assert num_codebooks > 1 # we may later handle this specially.
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.hidden_channels = hidden_channels
self.ignore_index = ignore_index
self.reduction = reduction
self.checkpoint = checkpoint
self.linear1 = nn.Linear(predictor_channels, hidden_channels)
        # codebook_embedding is used to predict each codebook from previous
        # codebooks, so it's a joint, not independent, model.  we'll scale it up
        # by first_embeddings_scale when we use it; keeping the stored magnitude
        # small allows it to train fast enough (relatively speaking).
self.codebook_embedding = nn.Embedding(
(num_codebooks - 1) * codebook_size,
hidden_channels,
_weight=torch.randn(
(num_codebooks - 1) * codebook_size, hidden_channels
)
* (hidden_channels ** -0.5),
)
self.linear2_weight = nn.Parameter(
torch.randn(num_codebooks, codebook_size, hidden_channels)
* (hidden_channels ** -0.5)
)
self.linear2b_weight = nn.Parameter(
torch.randn(num_codebooks, codebook_size, predictor_channels)
* (predictor_channels ** -0.5)
)
self.linear2_bias = nn.Parameter(
torch.zeros(num_codebooks, codebook_size)
)
def forward(
self, predictor: Tensor, codebook_indexes: Tensor
    ) -> Tensor:
"""
Forward function.
Args:
predictor: a Tensor of some real type, with shape (*, predictor_channels).
codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
where the '*' should be the same as for `predictor`. It will be
converted to type torch.int64. Should contain indexes of codebook
entries, in {0..codebook_size-1},
or negative values which will be interpreted as "no codebook index here"
(e.g. due to padding); we assume that each frame will either have
all-negative or all-nonnegative indexes, meaning that (codebook_indexes >= 0)
should not vary as you change the last index into it.
Returns:
           cross_entropy_loss: the total negated log-probability, assuming
reduction == 'sum'.
"""
args = (
predictor,
codebook_indexes,
self.linear1.weight,
self.linear1.bias,
self.codebook_embedding.weight,
self.linear2_weight,
self.linear2b_weight,
self.linear2_bias,
self.ignore_index,
self.reduction,
)
if self.checkpoint:
return checkpoint(joint_codebook_loss, *args)
else:
return joint_codebook_loss(*args)
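
Since checkpointing recomputes the forward pass during backprop, both settings should agree; a quick consistency sketch (small illustrative sizes, not from the original files):

torch.manual_seed(0)
m = Powerful_JointCodebookLoss(
    predictor_channels=32, num_codebooks=4, hidden_channels=16, codebook_size=8
)
x = torch.randn(5, 32, requires_grad=True)
idx = torch.randint(0, 8, (5, 4))
loss_a = m(x, idx)  # checkpoint=True is the default
loss_a.backward()
grad_a, x.grad = x.grad.clone(), None
m.checkpoint = False
loss_b = m(x, idx)
loss_b.backward()
assert torch.allclose(loss_a, loss_b) and torch.allclose(grad_a, x.grad)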

View File: egs/librispeech/ASR/conformer_ctc/prediction.py

@@ -0,0 +1,187 @@
import torch
from torch import nn
from torch import Tensor
from typing import Tuple
class JointCodebookPredictor(nn.Module):
"""
This module predicts a group of codebook indexes from a vector. The idea is that
you have a number of codebooks (probably jointly trained), from class Quantizer,
and you want to predict the probabilities of the codebook entries based on some
predictor that you are training.
The simplest thing would be to project the vector using nn.Linear, then
reshape and use logsoftmax to normalize the probabilities within each group,
then compute the likelihood. However, this has a constraint that all the
codebooks are predicted independently of each other. This module allows you
to predict them jointly, by regressing each codebook on all previous codebooks.
This is done with a nonlinearity in which the previous codebook entries are combined
with the input predictor vector, so that the regression is not purely
linear.
Args:
predictor_dim: the number of features that we use to predict the codebook
indexes, e.g. 2048 (will depend on your model).
num_codebooks: the number of codebooks that you are predicting;
will likely be the same as the bytes_per_frame given to the
QuantizerTrainer that you used to train the Quantizer you
are predicting.
codebook_size: number of entries per codebook (often 256)
      self_prediction: you can set this to false to disable prediction of
           codebooks from earlier-numbered codebooks
hidden_dim: the hidden dimension per codebook (we use a 1-hidden-layer
network, with a ReLU and then batchnorm).
"""
def __init__(
self,
predictor_dim: int,
num_codebooks: int,
codebook_size: int = 256,
self_prediction: bool = True,
hidden_dim: int = 384,
):
super(JointCodebookPredictor, self).__init__()
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.hidden_dim = hidden_dim
self.linear1 = nn.Linear(predictor_dim, num_codebooks * hidden_dim)
if self_prediction:
linear_self_out_dim = (num_codebooks - 1) * hidden_dim
linear_self_in_dim = (num_codebooks - 1) * codebook_size
self.linear_self = nn.Parameter(
torch.randn(linear_self_out_dim, linear_self_in_dim)
* (linear_self_in_dim ** -0.5)
)
            # With num_codebooks == 3, hidden_dim == 2 and codebook_size == 2,
# the expression below has the value:
# tensor([[ True, True, False, False],
# [ True, True, False, False],
# [ True, True, True, True],
# [ True, True, True, True]])
self.register_buffer(
"linear_self_mask",
(
(torch.arange(linear_self_out_dim) // hidden_dim).unsqueeze(
1
)
>= (
torch.arange(linear_self_in_dim) // codebook_size
).unsqueeze(0)
),
)
else:
self.register_parameter("linear_self", None)
self.register_buffer("linear_self_mask", None)
self.norm = nn.BatchNorm1d(num_codebooks * hidden_dim)
self.linear2 = nn.Parameter(
torch.randn(num_codebooks, codebook_size, hidden_dim)
* (hidden_dim ** -0.5)
)
self.bias2 = nn.Parameter(torch.zeros(num_codebooks, codebook_size))
def forward(
self, predictor: Tensor, codebook_indexes: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Forward function.
Args:
predictor: a Tensor of some real type, with shape (*, predictor_dim).
codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
where the '*' should be the same as for `predictor`. It will be
converted to type torch.int64. Should contain indexes of codebook
entries, in {0..codebook_size-1},
or negative values which will be interpreted as "no codebook index here"
(e.g. due to padding); we assume that each frame will either have
all-negative or all-nonnegative indexes, meaning that (codebook_indexes >= 0)
should not vary as you change the last index into it.
Returns: total_logprob, total_count, where:
total_logprob: a scalar Tensor, containing the total log-probability of all
the nonnegative codebook indexes,
           total_count: a scalar Tensor containing the total count of frames with
              nonnegative indexes, satisfying total_count <= codebook_indexes.numel() / num_codebooks
"""
codebook_indexes = codebook_indexes.to(torch.int64)
assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
assert codebook_indexes.shape[-1] == self.num_codebooks
tot_codebook_dim = self.num_codebooks * self.codebook_size
common_shape = list(predictor.shape[:-1])
codebook_one_hot = torch.zeros(
*common_shape,
self.num_codebooks,
self.codebook_size,
device=predictor.device
)
codebook_mask = (codebook_indexes >= 0).unsqueeze(
-1
) # (*, num_codebooks, 1)
codebook_indexes_floor = torch.clamp(codebook_indexes, min=0).unsqueeze(
-1
) # (*, num_codebooks, 1)
codebook_one_hot.scatter_(
dim=-1,
index=codebook_indexes_floor,
src=codebook_mask.to(torch.float32),
)
codebook_one_hot = codebook_one_hot.reshape(
*common_shape, tot_codebook_dim
) # (*, tot_codebook_dim)
hidden = self.linear1(predictor)
if self.linear_self is not None:
codebook_one_hot_part = torch.narrow(
codebook_one_hot, -1, 0, tot_codebook_dim - self.codebook_size
)
self_predictor = torch.matmul(
codebook_one_hot_part,
(self.linear_self * self.linear_self_mask).transpose(0, 1),
)
# add the 'self_predictor' term to all but the 1st
# block of "hidden".
hidden_part = torch.narrow(
hidden,
-1,
self.hidden_dim,
self.hidden_dim * (self.num_codebooks - 1),
)
hidden_part += self_predictor
hidden = nn.functional.relu(hidden)
hidden = hidden.reshape(-1, self.hidden_dim * self.num_codebooks)
hidden = self.norm(hidden)
hidden = hidden.reshape(
*common_shape, self.num_codebooks, self.hidden_dim
) # (*, num_codebooks, hidden_dim)
logprobs = (
torch.matmul(
hidden.unsqueeze(-2), self.linear2.transpose(1, 2)
).squeeze(-2)
+ self.bias2
)
# logprobs: (*, num_codebooks, codebook_size)
logprobs = logprobs.log_softmax(dim=-1)
logprobs = logprobs.reshape(
*common_shape, self.num_codebooks * self.codebook_size
)
tot_logprob = torch.dot(
logprobs.reshape(-1), codebook_one_hot.reshape(-1)
)
assert tot_logprob <= 0.0
# the select() part is to select only the mask for one of the codebooks (they should
# all be the same), as we want the total number of frames.
tot_count = codebook_mask.select(dim=-2, index=0).sum()
return (tot_logprob, tot_count)
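
A minimal usage sketch (sizes are illustrative; eval() is used so BatchNorm1d applies its default running statistics to this single small batch):

predictor = JointCodebookPredictor(predictor_dim=64, num_codebooks=4, codebook_size=8)
predictor.eval()
x = torch.randn(5, 64)
idx = torch.randint(0, 8, (5, 4))
idx[0, :] = -1  # a padded frame: excluded from both the logprob and the count
tot_logprob, tot_count = predictor(x, idx)
assert tot_count.item() == 4
print("average logprob per counted frame:", (tot_logprob / tot_count).item())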

View File: egs/librispeech/ASR/conformer_ctc/quantization.py

@@ -0,0 +1,923 @@
import binascii
import h5py
import math
import numpy as np
import os
import time
import torch
import random
import logging
from torch import nn
from torch import Tensor
from typing import Tuple
class Quantizer(nn.Module):
# what this is implementing appears to be referred to as direct-sum codebooks in the scientific literature.
# see also residual vector quantization, or multistage VQ, although this method jointly optimizes
# the codebook entries so there is no order.
def __init__(self, dim: int, codebook_size: int, num_codebooks: int):
"""
Trainable quantizer that encodes a vector into a sequence of integers (corresponding
to multiple separate codebooks), aiming to get the least possible expected squared
difference.
"""
super(Quantizer, self).__init__()
self.dim = dim
self.codebook_size = codebook_size
self.num_codebooks = num_codebooks
def is_power_of_two(n: int) -> bool:
return (n & (n - 1) == 0) and n != 0
assert is_power_of_two(codebook_size)
assert is_power_of_two(num_codebooks)
self.to_logits = nn.Linear(dim, codebook_size * num_codebooks)
self.logits_scale = 4
# self.centers: (num_codebooks, codebook_size, dim)
self.centers = nn.Parameter(
self.to_logits.weight.detach()
.clone()
.reshape(num_codebooks, codebook_size, dim)
)
# We give each Quantizer a unique 8-digit hex identifier, which we'll use to reduce the
# probability of mixing up the outputs of different Quantizers.
# It is saved as a buffer, as well as a string, so that it will be loaded
# from disk when we use load_state_dict().
id_bytes = binascii.b2a_hex(
os.urandom(4)
) # random hex string, e.g. b'585ce3cf'
self.id_str = id_bytes.decode("utf-8")
self.register_buffer(
"id_buf", torch.tensor(list(id_bytes), dtype=torch.uint8)
)
def load_state_dict(self, *args, **kwargs):
super(Quantizer, self).load_state_dict(*args, **kwargs)
self.id_str = bytes(self.id_buf.tolist()).decode("utf-8")
def get_id(self) -> str:
return self.id_str
def show_init_invocation(self) -> str:
return f"quantization.Quantizer(dim={self.dim}, codebook_size={self.codebook_size}, num_codebooks={self.num_codebooks})"
def get_data_mean(self) -> Tensor:
"""
Return an approximate expression for the mean of the training data, as a tensor
of shape (dim,). This is useful for diagnostics. It is detached from gradient,
to avoid this affecting the optimization.
The expression we use assumes balanced codebook probabilities, which is true
in practice (as long as index_entropy_loss in training is fairly small).
"""
return self.centers.mean(dim=1).sum(dim=0).detach()
def get_product_quantizer(self) -> "Quantizer":
"""
Returns a Quantizer object with codebook_size = self.codebook_size**2 and
num_codebooks = self.num_codebooks//2, initialized so that each codebook
in the result is formed from pairs of codebooks in this object.
"""
new_codebook_size = self.codebook_size ** 2
new_num_codebooks = self.num_codebooks // 2
ans = Quantizer(self.dim, new_codebook_size, new_num_codebooks).to(
self.centers.device
)
ans.apply_mask = False
with torch.no_grad():
for c_out in range(new_num_codebooks):
c_in1 = 2 * c_out
c_in2 = 2 * c_out + 1
for k_in1 in range(self.codebook_size):
row_in1 = self.codebook_size * c_in1 + k_in1
for k_in2 in range(self.codebook_size):
row_in2 = self.codebook_size * c_in2 + k_in2
k_out = k_in1 * self.codebook_size + k_in2
row_out = new_codebook_size * c_out + k_out
ans.to_logits.weight[row_out, :] = (
self.to_logits.weight[row_in1]
+ self.to_logits.weight[row_in2]
)
ans.to_logits.bias[row_out] = (
self.to_logits.bias[row_in1]
+ self.to_logits.bias[row_in2]
)
ans.centers[c_out, k_out, :] = (
self.centers[c_in1, k_in1, :]
+ self.centers[c_in2, k_in2, :]
)
return ans
def decode(self, indexes: Tensor) -> Tensor:
"""
Does the (approximate) inverse of _compute_indexes(): constructs from `indexes` the
corresponding approximated tensor.
Args:
indexes:
            May be an integer tensor of shape (*, self.num_codebooks), with entries
            in {0..self.codebook_size-1}
May also contain multiple codebook entries combined into one integer, as
done by encode() with as_bytes==True; in this case the last dim
might be self.num_codebooks/2 or self.num_codebooks/4.
Returns: a tensor of shape (*, self.dim), consisting of the sum of the specified
cluster centers.
"""
orig_shape = indexes.shape
indexes = indexes.reshape(-1, indexes.shape[-1])
indexes = self._maybe_separate_indexes(indexes).to(dtype=torch.int64)
assert indexes.ndim == 2
B = indexes.shape[0]
# indexes_expanded: (num_codebooks, B, dim)
indexes_expanded = (
indexes.transpose(0, 1)
.contiguous()
.unsqueeze(-1)
.expand(self.num_codebooks, B, self.dim)
)
# self.centers: (num_codebooks, codebook_size, dim)
# chosen_codebooks: (num_codebooks, B, dim).
chosen_codebooks = torch.gather(
self.centers, dim=1, index=indexes_expanded
)
# x_approx: (B, dim), this is the sum of the chosen rows of `to_output`
# corresponding to the chosen codebook entries, this would correspond to
# the approximated x.
x_approx = chosen_codebooks.sum(dim=0)
return x_approx.reshape(*orig_shape[:-1], self.dim)
def compute_codebook_correlations(self) -> Tensor:
"""
Return a Tensor of shape (self.num_codebooks, self.num_codebooks)
with values >= 0, which are greater if a pair of codebooks more strongly
shares a subspace. This is for diagnostic purposes.
These correlations are computed by:
- subtracting the mean value from each codebook
- creating an uncentered variance S_i for each codebook i
- computing, for each pair of codebooks i and j, c_{ij} = tr(S_i S_j)
        - returning c_{ij} / sqrt(c_{ii} c_{jj}), which is a symmetric
matrix with values in [0,1]
"""
centers = self.centers.detach()
codebook_means = centers.mean(
dim=1, keepdim=True
) # (num_codebooks, 1, dim)
centers = centers - codebook_means # make each codebook zero-mean.
# variances: (num_codebooks, dim, dim)
variances = torch.matmul(centers.transpose(1, 2), centers)
# variances_flat: (num_codebooks, dim * dim)
variances_flat = variances.reshape(
self.num_codebooks, self.dim * self.dim
)
# cross_variances: (num_codebooks, num_codebooks), should be all positive
# (interpret these as tr(0.5*(V1 * V2 + V2 * V1)) == tr(V1 * V2) ==
# the sum of products of corresponding elements (for this, we use the fact
# that V1 and V2 are both symmetric).
cross_variances = torch.matmul(variances_flat, variances_flat.t())
normalizer = cross_variances.diag() ** -0.5
normalizer = normalizer.unsqueeze(0) * normalizer.unsqueeze(1)
return cross_variances * normalizer
def compute_loss(self, x: Tensor, refine_indexes_iters: int = 0) -> Tensor:
"""
Compute various parts of the loss function.
Args:
x: the Tensor to quantize, of shape (*, dim)
refine_indexes_iters: a number >= 0: the number of iterations to refine
the indexes from their initial value.
        Returns: (rel_reconstruction_loss, logprob_loss, logits_entropy_loss, index_entropy_loss), where
          rel_reconstruction_loss: a scalar torch.Tensor containing the relative sum-squared
             reconstruction loss, based on the indexes chosen after `refine_indexes_iters`
             iterations of refinement after the argmax of the logits.  This loss
             is the sum-squared of (x - reconstructed_x) / (sum-squared of x - x_mean), which
for already-trained models will be between 0 and 1, but could be greater than 1
at the start of training.
logprob_loss: the negative average logprob of the selected classes (i.e. those
selected after refine_indexes_iters of refinement). This is added to the
loss function, so we can select reasonable classes before refining the indexes.
logits_entropy_loss: the class entropy loss, from the logits, which approaches
zero when all classes of all codebooks are equi-probable (in the logits output).
index_entropy_loss: the class entropy loss, from the computed indexes, which approaches
zero when all classes of all codebooks are equi-probable (in the indexes output).
Not differentiable but useful for diagnostics.
"""
x = x.reshape(-1, self.dim)
indexes = self._compute_indexes(x, refine_indexes_iters)
x_approx = self.decode(indexes)
# tot_error: (B, dim), the error of the approximated vs. real x.
tot_error = x_approx - x
rel_reconstruction_loss = (tot_error ** 2).sum() / (
((x - self.get_data_mean()) ** 2).sum() + 1.0e-20
)
# Get logprob loss and class-entropy loss
# wasteful.. already computed logits..
logits = self._logits(x).reshape(
-1, self.num_codebooks, self.codebook_size
)
logits = logits.log_softmax(dim=2)
# chosen_logits: (B, num_codebooks, 1)
chosen_logits = torch.gather(logits, dim=2, index=indexes.unsqueeze(2))
logprob_loss = -chosen_logits.mean()
# class_entropy
B = x.shape[0]
counts = torch.zeros(
B, self.num_codebooks, self.codebook_size, device=x.device
)
ones = torch.ones(1, 1, 1, device=x.device).expand(
B, self.num_codebooks, self.codebook_size
)
counts.scatter_(src=ones, dim=2, index=indexes.unsqueeze(2))
avg_counts = counts.mean(dim=0) + 1.0e-20
index_entropy = -(avg_counts * avg_counts.log()).sum(dim=1).mean()
probs = logits.exp().mean(dim=0) + 1.0e-20
logits_entropy = -(probs * probs.log()).sum(dim=1).mean()
ref_entropy = math.log(self.codebook_size)
logits_entropy_loss = (ref_entropy - logits_entropy) / ref_entropy
index_entropy_loss = (ref_entropy - index_entropy) / ref_entropy
return (
rel_reconstruction_loss,
logprob_loss,
logits_entropy_loss,
index_entropy_loss,
)
def encode(
self, x: Tensor, refine_indexes_iters: int = 5, as_bytes: bool = True
) -> Tensor:
"""
Compute the quantized output, that can be used to reconstruct x.
Args:
x: the Tensor to quantize, of shape (*, dim)
refine_indexes_iters: a number >= 0: the number of iterations to refine
the indexes from their initial value.
          as_bytes: if True, the quantized output will be returned as a byte
             array, combining as many codes as possible into each byte
             if codebook_size <= 16.
        Returns: if as_bytes == False, a torch.LongTensor of shape (*, num_codebooks);
          if as_bytes == True, a Tensor of dtype=torch.uint8, of shape
          (*, num_codebooks/n), where n == 4 if codebook_size <= 4;
          n == 2 if codebook_size <= 16; else n == 1.
"""
x_reshaped = x.reshape(-1, self.dim)
indexes = self._compute_indexes(x_reshaped, refine_indexes_iters)
if as_bytes:
codebook_size = self.codebook_size
while codebook_size ** 2 <= 256:
indexes = indexes[:, ::2] + codebook_size * indexes[:, 1::2]
codebook_size = codebook_size ** 2
assert codebook_size <= 256
indexes = indexes.to(torch.uint8)
return indexes.reshape(*x.shape[:-1], -1)
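    # Worked example of the packing loop above: with codebook_size == 16 and
    # num_codebooks == 8, one pass combines adjacent pairs as
    # indexes[:, ::2] + 16 * indexes[:, 1::2], leaving 4 columns with values
    # in {0..255} (two codes per byte); with codebook_size == 4 the loop runs
    # twice, packing four codes per byte.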
def _logits(self, x: Tensor) -> Tensor:
return self.to_logits(x) * self.logits_scale
def _compute_indexes(
self, x: Tensor, refine_indexes_iters: int = 3
) -> Tensor:
"""
Deterministically compute the indexes that encode the tensor x.
Args:
x: the Tensor to quantize, of shape (B, dim)
refine_indexes_iters: a number >= 0: the number of iterations to refine
the indexes from their initial value.
Returns: returns a torch.LongTensor of shape (B, num_codebooks),
with entries in {0..codebook_size-1}
"""
assert x.ndim == 2 and x.shape[1] == self.dim
B = x.shape[0]
x_reshaped = x.reshape(-1, self.dim)
B = x_reshaped.shape[0]
logits = self._logits(x_reshaped)
logits = logits.reshape(B, self.num_codebooks, self.codebook_size)
# indexes: (B, self.num_codebooks)
indexes = torch.argmax(logits, dim=-1)
for i in range(refine_indexes_iters):
indexes = self._refine_indexes(x_reshaped, indexes)
assert indexes.ndim == 2
return indexes.reshape(*x.shape[:-1], self.num_codebooks)
def _refine_indexes(self, x: Tensor, indexes: Tensor) -> Tensor:
"""
        Refine choices of indexes, minimizing sum-squared loss.  Note, this is not
        guaranteed not to increase the sum-squared loss, but works OK in practice.
Args:
x: A Tensor of shape (B, self.dim) to be approximated.
indexes: A Tensor of integer type, of shape (B, self.num_codebooks),
that contains elements in {0..self.codebook_size-1}
Returns: A tensor of indexes of shape (B, self.num_codebooks) that
will hopefully reduce the error w.r.t. x, better or at least no worse
than `indexes`. This algorithm is not exact, but if the codebooks are
fairly orthogonal it should work fine. If they are not fairly orthogonal
it may not optimize well, but hopefully the codebooks will then learn
to be more orthogonal.
"""
B = indexes.shape[0]
# indexes_expanded has shape (B, self.num_codebooks, 1, self.dim)
indexes_expanded = (
indexes.unsqueeze(-1)
.unsqueeze(-1)
.expand(B, self.num_codebooks, 1, self.dim)
)
# all_centers: (1, num_codebooks, codebook_size, dim)
all_centers = self.centers.unsqueeze(0)
# centers_expanded has shape (B, self.num_codebooks, self.codebook_size, self.dim)
centers_expanded = all_centers.expand(
B, self.num_codebooks, self.codebook_size, self.dim
)
# old_centers: (B, self.num_codebooks, 1, self.dim)
# centers with the "indexes" as passed-in.
old_centers = torch.gather(
centers_expanded, dim=2, index=indexes_expanded
)
# x_err is of shape (B, 1, 1, self.dim), it is the old value of (x_approx - x)
x_err = old_centers.sum(dim=1, keepdim=True) - x.unsqueeze(1).unsqueeze(
2
)
# The algorithm below is going to be iterative, where at each stage we
# have N K-way choices, with each choice corresponding to L codebook indexes.
# Initially N == num_codebooks, K == codebook_size, L == 1,
# and on the iterations of the algorithm we either:
# - terminate by finding the best choice, if currently N == 1
# - combine pairs of choices, so that N := N//2, K := K ** 2, L *= 2
# - reduce K to K_cutoff by sorting and taking the K_cutoff best possibilities
# for each choice. K_cutoff is a power of 2 that starts at 8 or 16
# and doubles every 2 iterations to keep the work per iteration
# fairly constant.
# At all points in the algorithm we maintain cur_sumsq and (conceptually)
# cur_deltas (however in some parts cur_deltas is not instantiated, see
# gather_deltas).
#
# cur_indexes: (B, N, K, L), initially (B, num_codebooks, codebook_size, 1),
# gives the codebook indexes corresponding to the k'th value of the n'th
# choice. Initially this is just an arange expression but from the 1st
# iter of the algorithm it changes to something nontrivial.
#
# cur_sumsq: (B, N, K), is the sum-squared error of x versus its predicted value
# from the codebooks, if we were to
# make the n'th choice with value k without making any of the other N-1 choices, i.e.
# if we were to leave the other choices at the value we had at input.
# Specifically, it is always supposed to equal the value of
# ((x_err + cur_deltas)**2).sum(dim=-1)
# .. but we keep it around separately because it enables an optimization.
#
# cur_deltas: (B, N, K, dim), is the change in x_err (with x_err =
# x_approx - x and x_approx being a sum of codebook indexes) if we were
# to make the n'th choice with value k without making any of the other
# N-1 choices.
# At the current point, i.e. at the start of the algorithm,
# cur_deltas[b][n][k] says "what would be the change in x_err if we
# were to replace the current choice of the n'th codebook entry-- i.e.
# the choice reflected in `indexes`-- with value k? [In general,
# cur_deltas[b][n][k] refers not directly to a codebook indexes, but
# to an indexes into `cur_indexes` which corresponds to the sequence/combination
# of codebook indexes that are stored in cur_indexes[b][n][k].
# cur_deltas represents the change in x_err from making each choice (while
# leaving all the other choices un-made by just keeping the passed-in/old
# indexes).
# cur_deltas: (B, N, K, dim),
N = self.num_codebooks
K = self.codebook_size
L = 1 # L is the number of codebooks covered by each choice.
# Conceptually we could do:
# cur_deltas = all_centers - old_centers # (B, N, K, dim)
# ... however actually we won't be instantiating cur_deltas at this stage of the
# algorithm.
dim = self.dim
# cur_indexes is the codebook indexes corresponding to 'cur_deltas'.
cur_indexes = (
torch.arange(K, device=x.device)
.reshape(1, 1, K, 1)
.expand(B, N, K, L)
)
if True:
# compute cur_sumsq using an efficient approach
x_err_sumsq = (x_err ** 2).sum(dim=-1) # (B, 1, 1)
x_remaining = (
x_err - old_centers
) # (B, num_codebooks, 1, dim): the x_err after subtracting
# each of the codebooks; if we add back to this any given
# codebook vector (from all_centers), we'll get the error
# if we were to
# choose that codebook entry instead of the one actually chosen.
x_remaining_sumsq = (x_remaining ** 2).sum(
dim=-1
) # (B, num_codebooks, 1)
# all_centers_sumsq is the sumsq of all the centers..
all_centers_sumsq = (all_centers ** 2).sum(
dim=-1
) # (1, num_codebooks, codebook_size)
cross_sum = torch.matmul(
all_centers, # (1, num_codebooks, codebook_size, dim)
x_remaining.permute(2, 1, 3, 0), # (1, num_codebooks, dim, B)
) # (1, num_codebooks, codebook_size, B)
cross_sum = cross_sum.squeeze(0).permute(
2, 0, 1
) # (B, num_codebooks, codebook_size)
# (B, num_codebooks, codebook_size); interpret as (B, N, K)
cur_sumsq = x_remaining_sumsq + all_centers_sumsq + 2 * cross_sum
assert cur_sumsq.shape == (B, N, K)
# gather_deltas (which will be re-defined below) is a lambda from
# `this_indexes`, a LongTensor of shape (B, N, new_K, 1) [which
# at the current iteration would equal (B, num_codebooks, new_K, 1)]
# with elements in
# {0..K-1} [i.e. 0..codebook_size-1], to the new "cur_deltas".
# It is provided as a workaround in
# case we did not physically instantiate cur_deltas on this iteration.
# In general cur_deltas is supposed to represent "change in encoded
# value" if we were to make a particular modified index choice, leaving
# all other choices as they were on entry.
# gather_deltas is supposed to be a lambda from this_indexes to the
# something equivalent to following expression (if cur_deltas had actually
# existed):
# torch.gather(input=cur_deltas, dim=2, index=this_indexes.expand(B, N, new_K, dim))
gather_deltas = lambda this_indexes: (
torch.gather(
input=all_centers.expand(B, N, K, dim),
dim=2,
index=this_indexes.expand(B, N, -1, dim),
)
- old_centers
)
else:
cur_deltas = all_centers - old_centers # (B, N, K, dim)
## cur_sumsq: (B, N, K), equivalent to: ((x_err + cur_deltas)**2).sum(dim=-1)
            ## We really want a batched vector-vector product here, which torch does not
## explicitly support, so we use a matrix multiplication with 1x1 output.
modified_err = x_err + cur_deltas # (B, N, K, dim)
cur_sumsq = (
torch.matmul(
modified_err.unsqueeze(-2), modified_err.unsqueeze(-1)
)
.squeeze(-1)
.squeeze(-1)
)
gather_deltas = None
# x_err_sumsq: (B, 1, 1), is the sum-squared of x_err; we'll need it in the loop.
x_err_sumsq = (x_err ** 2).sum(dim=-1)
K_cutoff_base = 8 if self.codebook_size <= 16 else 16
def get_K_cutoff():
            # Every time L increases by a factor of 4, we double K_cutoff.  This keeps the
# work per iteration roughly constant, as it's linear in 1/L
# and in K_cutoff**2.
K_cutoff, l = K_cutoff_base, L
while l >= 4:
l /= 4
K_cutoff *= 2
return min(K_cutoff, 128)
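        # Worked values of the schedule above: with K_cutoff_base == 8 this
        # gives K_cutoff = 8 for L in {1, 2}, 16 for L in {4, 8}, 32 for L in
        # {16, 32}, and so on, capped at 128.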
while True:
K_cutoff = get_K_cutoff()
if N == 1 and K == 1:
return cur_indexes.squeeze(2).squeeze(
1
) # (B, L) == (B, num_codebooks)
elif K > K_cutoff or N == 1:
# Sort the options for each choice, and reduce K.
# this_indexes: (B, N, K); elements in {0..K-1}. These
# are sorted from best (lowest) to worst.
_, this_indexes = torch.sort(cur_sumsq, dim=2)
new_K = 1 if N == 1 else K_cutoff
this_indexes = this_indexes[:, :, :new_K]
cur_sumsq = torch.gather(
input=cur_sumsq, dim=2, index=this_indexes
)
this_indexes = this_indexes.unsqueeze(-1)
# cur_indexes is (B, N, new_K, L), but with only the chosen
# indexes kept.
cur_indexes = torch.gather(
input=cur_indexes,
dim=2,
index=this_indexes.expand(B, N, new_K, L),
)
if gather_deltas is None:
# also sort cur_deltas in the same way
cur_deltas = torch.gather(
input=cur_deltas,
dim=2,
index=this_indexes.expand(B, N, new_K, dim),
)
else:
# gather_deltas should be a lambda from:
# this_indexes: a LongTensor of shape (B, N, new_K, 1) containing elements in {0..K-1}
# to the new "deltas" which should be of shape
# (B, N, new_K, dim)
# representing the difference from the baseline "x_offset" if we choose this
# index for this codebook or range of codebooks, leaving other choices
# as they were at entry to this function.
cur_deltas = gather_deltas(this_indexes)
gather_deltas = None
K = new_K
else:
# Combine pairs of choices. We know that N > 1.
even_deltas = cur_deltas[:, 0::2, :, :]
odd_deltas = cur_deltas[:, 1::2, :, :]
even_indexes = cur_indexes[:, 0::2, :, :]
odd_indexes = cur_indexes[:, 1::2, :, :]
even_sumsq = cur_sumsq[:, 0::2, :]
odd_sumsq = cur_sumsq[:, 1::2, :]
new_N = N // 2
new_K = K ** 2
new_L = L * 2
even_indexes = (
even_indexes.unsqueeze(3)
.expand(B, new_N, K, K, L)
.reshape(B, new_N, new_K, L)
)
odd_indexes = (
odd_indexes.unsqueeze(2)
.expand(B, new_N, K, K, L)
.reshape(B, new_N, new_K, L)
)
cur_indexes = torch.cat((even_indexes, odd_indexes), dim=3)
even_sumsq = even_sumsq.unsqueeze(3) # (B, new_N, K, 1)
odd_sumsq = odd_sumsq.unsqueeze(2) # (B, new_N, 1, K)
# cur_sumsq below is a partial version, we have to add another term.
# The new version of cur_sumsq that we want can be expressed as:
# ((a + b + c)**2).sum(dim=-1),
# where a = x_err, b == even_deltas, c == odd_deltas. Ignoring the summation, we
# can write this as:
# a^2 + b^2 + c^2 + 2ab + 2ac + 2bc.
# We can rearrange this as:
# (a^2 + b^2 + 2ab) + (a^2 + c^2 + 2ac) - a^2 + 2bc,
# which is the same as
# even_sumsq + odd_sumsq - x_err_sumsq + 2bc,
# where 2bc is a certain matrix product of odd_deltas and even_deltas.
cur_sumsq = (
(even_sumsq + odd_sumsq).reshape(B, new_N, new_K)
- x_err_sumsq
+ 2
* torch.matmul(
even_deltas, odd_deltas.transpose(2, 3)
).reshape(B, new_N, new_K)
)
saved_K = K
gather_deltas = lambda this_indexes: (
torch.gather(
input=even_deltas,
dim=2,
index=(this_indexes // saved_K).expand(
*this_indexes.shape[:-1], dim
),
)
+ torch.gather(
input=odd_deltas,
dim=2,
index=(this_indexes % saved_K).expand(
*this_indexes.shape[:-1], dim
),
)
)
cur_deltas = None # Unset it, it is now invalid, but we'll reconstruct it using gather_deltas.
N, K, L = new_N, new_K, new_L
assert cur_indexes.shape == (B, N, K, L)
assert cur_sumsq.shape == (B, N, K)
def _maybe_separate_indexes(self, indexes: Tensor) -> Tensor:
"""
This reverses the process done in encode() if as_bytes==True, which combines
multiple codebook entries into a single byte if self.codebook_size is small
enough.
Args:
indexes: an integer tensor of shape (B, n) where n divides
self.num_codebooks
Returns: a tensor of the same type as `indexes`, of shape (B,
self.num_codebooks)
"""
B = indexes.shape[0]
if indexes.shape[-1] != self.num_codebooks:
n = indexes.shape[-1]
num_repeats = self.num_codebooks // n
assert (
num_repeats in [2, 4, 8, 16]
and self.num_codebooks == n * num_repeats
)
indexes = indexes.unsqueeze(2).expand(B, n, num_repeats)
size = self.codebook_size
indexes = (
indexes
// (size ** torch.arange(num_repeats, device=indexes.device))
) % size
indexes = indexes.reshape(B, self.num_codebooks)
assert indexes.shape == (B, self.num_codebooks)
return indexes
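    # Worked example: with codebook_size == 16 and num_codebooks == 8,
    # decode() may receive 4 bytes per frame (num_repeats == 2); a byte b
    # then expands to [b % 16, (b // 16) % 16], e.g. b = 171 -> [11, 10],
    # reversing the packing indexes[:, ::2] + 16 * indexes[:, 1::2] done
    # in encode().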
class QuantizerTrainer(object):
def __init__(
self,
dim: int,
bytes_per_frame: int,
device: torch.device,
phase_one_iters: int = 10000,
phase_two_iters: int = 10000,
lr: float = 0.005,
):
"""
Args:
dim: The feature dimension we are trying to quantize, e.g. 512
bytes_per_frame: The number of bytes to use to quantize each vector of
`dim` values.
device: The device to use for training
phase_one_iters: The number of iterations to use for the first
phase of training (with codebook_size=16); after this we
will convert to have codebook_size=256. These parameters were
tuned with a batch size of 600: if your batch size (in frames)
is smaller than this you may benefit from a larger phase_one_iters and a
smaller learning rate.
[Also, note: phase_one_iters should be larger for larger dims;
for dim=256 and batch_size=600, 10k was enough, but for
dim=512 and batch_size=600, 20k was better.
phase_two_iters: The number of iterations to use for the second
phase of training (with codebook_size=256)
lr: The initial learning rate.
This object trains a Quantizer. You can use it as follows:
trainer = QuantizerTrainer(...)
while not trainer.done():
# let x be some tensor of shape (*, dim), that you will train on
# (should not be the same on each minibatch)
trainer.step(x)
quantizer = trainer.get_quantizer()
"""
super(QuantizerTrainer, self).__init__()
assert bytes_per_frame in [1, 2, 4, 8, 16, 32]
# We'll initially train with codebook_size=16 and
# num_codebooks=bytes_per_frame * 2, then after `phase_one_iters` of
# training will multiply pairs of codebooks so codebook_size=256 and
# num_codebooks=bytes_per_frame
self.phase_one_iters = phase_one_iters
self.phase_two_iters = phase_two_iters
self.cur_iter = 0
self.lr = lr
self.two_iter_prob = 0.5
self.quantizer = Quantizer(
dim=dim, codebook_size=16, num_codebooks=bytes_per_frame * 2
).to(device)
self.start_time = time.time()
self._init_optimizer()
def done(self) -> bool:
ans = self.cur_iter > self.phase_one_iters + self.phase_two_iters
if ans:
elapsed_time = time.time() - self.start_time
logging.info(
f"Elapsed time, training model of dim={self.quantizer.dim}, num_codebooks={self.quantizer.num_codebooks}, "
f"codebook_size={self.quantizer.codebook_size}, is: {elapsed_time:.2f} seconds."
)
return ans
def step(self, x: torch.Tensor) -> None:
"""
        Does one step of training.  You must call this at least
        phase_one_iters + phase_two_iters times (use done() to check).
Args:
x: a Tensor of shape (*, dim) containing the frames of data we are
trying to accurately encode.
"""
x = x.reshape(-1, self.quantizer.dim)
num_iters = 2 if random.random() < self.two_iter_prob else 1
(
reconstruction_loss,
logprob_loss,
logits_entropy_loss,
index_entropy_loss,
) = self.quantizer.compute_loss(x, num_iters)
if self.cur_iter % 200 == 0:
det_losses = [
float("%.3f" % self.quantizer.compute_loss(x, j)[0].item())
for j in range(6)
]
phase = 1 if self.cur_iter <= self.phase_one_iters else 2
i = (
self.cur_iter - self.phase_one_iters
if phase > 1
else self.cur_iter
)
# Caution: python's logging level is logging.ERROR by default. To make the following
# be printed, you may have to do:
# import logging
# logging.getLogger().setLevel(logging.INFO)
# before using this code.
logging.info(
f"phase={phase}/2, iter={i}, "
f"dim,nc,csz={self.quantizer.dim},{self.quantizer.num_codebooks},{self.quantizer.codebook_size}, "
f"loss_per_iter={det_losses}, "
f"logprob_loss={logprob_loss.item():.3f}, "
f"logits_entropy_loss={logits_entropy_loss.item():.3f}, "
f"index_entropy_loss={index_entropy_loss.item():.3f}"
)
if self.cur_iter % 2000 == 0 and self.cur_iter > 0:
correlations = self.quantizer.compute_codebook_correlations()
logging.info(f"correlations = {correlations}")
# We did not find it necessary to use entropy_scale -- the
# logits_entropy_loss and index_entropy_loss are less than 0.01 even
# with entropy_scale == 0 -- but we are putting a nonzero value on
# entropy_scale just out of an abundance of caution, in case an unusual
# data distribution might cause problems in the future.
entropy_scale = 0.01
# About the losses:
# - reconstruction_loss >= 0; it equals 0 when reconstruction is exact.
# This is the main loss function, used to train quantizer.centers
# - logprob_loss trains only quantizer.to_logits, which predicts the
# indexes after refinement, so we can initialize them well; it does
# not affect the cluster centers.
# - logits_entropy_loss is currently not used for training, since we
# set entropy_scale = 0 above. It would affect only to_logits, if
# used. The intention was that this might solve problem with
# cluster centers having very uneven probabilities of being chosen
# (it would do this by biasing the initial choice, relying on
# the inexactness of the search). In our experiments,
# logits entropy_loss and index_entropy_loss both end up
# less than 0.05, so this does not seem to be a problem in practice,
# but it might be a problem if, say, the inputs had a very tiny scale,
# so we are keeping the code around.
# - index_entropy_loss is not differentiable; we have
# added it only for diagnostic purposes. It reflects the entropy of
# the distribution over classes, after refining the cluster indexes.
# It was computed just in case regularizing logits_entropy_loss was
# not enough to affect the final distribution over cluster centers,
# so we could diagnose the problem;
# but we found no problem in practice.
#
tot_loss = (
reconstruction_loss
+ logprob_loss
+ logits_entropy_loss * entropy_scale
)
tot_loss.backward()
self.optim.step()
self.optim.zero_grad()
self.scheduler.step()
if self.cur_iter == self.phase_one_iters:
self._begin_second_phase()
self.cur_iter += 1
def _init_optimizer(self):
self.optim = torch.optim.Adam(
self.quantizer.parameters(),
lr=self.lr,
betas=(0.9, 0.98),
eps=1e-9,
weight_decay=1.0e-06,
)
self.scheduler = torch.optim.lr_scheduler.StepLR(
self.optim,
step_size=(
self.phase_one_iters
if self.cur_iter == 0
else self.phase_two_iters
)
            // 4,
gamma=0.5,
)
def _begin_second_phase(self):
"""
This is to be called exactly once, when self.cur_iter reaches self.phase_one_iters
"""
self.quantizer = self.quantizer.get_product_quantizer()
self.lr *= 0.5
self._init_optimizer()
def get_quantizer(self) -> Quantizer:
assert self.cur_iter >= self.phase_one_iters + self.phase_two_iters
return self.quantizer
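# Illustrative usage sketch on random data (not in the original file; the
# phase lengths are shortened so this runs quickly, real training uses the
# defaults):
def _demo_train_quantizer():
    device = torch.device("cpu")
    trainer = QuantizerTrainer(dim=64, bytes_per_frame=4, device=device,
                               phase_one_iters=200, phase_two_iters=200)
    while not trainer.done():
        trainer.step(torch.randn(600, 64, device=device))
    quantizer = trainer.get_quantizer()
    x = torch.randn(10, 64)
    x_hat = quantizer.decode(quantizer.encode(x))
    print("relative error:", (((x - x_hat) ** 2).sum() / (x ** 2).sum()).item())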
def read_hdf5_data(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    Reads the hdf5 archive in the file with name 'filename' into a single
    numpy array of size (tot_frames, dim), shuffles the frames, and returns
    them, split into train and valid parts, as two torch Tensors.
    The dtype will be the same as it was in the archive (e.g. float16).
Args:
filename: the name of the filename of your hdf5 archive. It should
have been created using code similar to the code in test_write_hdf5.py,
e.g. something like:
hf = h5py.File(filename, 'w')
for i in range(...):
# get x as some numpy array of type np.float16, and shape (*, dim)
# the name does not actually matter,
# except that they should be distinct.
hf.create_dataset(f'dataset_{i}', data=x)
Returns (train, valid), where:
train: a torch.Tensor of shape (tot_train_frames, dim), on CPU, with
dtype=torch.float16, with shuffled rows.
valid: a torch.Tensor of shape (tot_valid_frames, dim), on CPU, with
dtype=torch.float16, with shuffled rows (these are distinct
              frames from those in `train`, but may derive from different
rows of the same original tensors.)
Caution: you should set the logger to INFO level, with:
logging.getLogger().setLevel(logging.INFO)
if you want to see the logging output of this function.
"""
logging.info(f"Opening file {filename}")
hf = h5py.File(filename, "r")
tot_frames = 0
dim = -1
def get_num_frames(shape):
# Returns product of shape[0],shape[1],...,shape[-2]
num_frames = 1
for i in shape[:-1]:
num_frames *= i
return num_frames
for key in hf.keys():
dset = hf[key]
shape = list(dset.shape)
if dim == -1:
dim = shape[-1]
else:
assert (
dim == shape[-1]
), "Dataset must have consistent dimension (last element of shape"
tot_frames += get_num_frames(shape)
logging.info(f"read_data: tot_frames = {tot_frames}")
ans = np.empty((tot_frames, dim), dtype=np.float16)
cur_pos = 0
for key in hf.keys():
array = hf[key][:] # [:] gets it as NumPy array (I believe).
array = np.ascontiguousarray(array).reshape(-1, dim)
num_frames = array.shape[0]
ans[cur_pos : cur_pos + num_frames, :] = array # noqa E203
cur_pos += num_frames
assert cur_pos == tot_frames
# Shuffle the rows of ans.
np.random.shuffle(ans)
ans_torch = torch.from_numpy(ans)
valid_proportion = 0.05
    valid_frames = int(valid_proportion * tot_frames)
if valid_frames > 10000:
valid_frames = 10000
train_frames = tot_frames - valid_frames
logging.info(
f"read_data: train_frames={train_frames}, valid_frames={valid_frames}"
)
# return (train, valid)
return ans_torch[valid_frames:tot_frames], ans_torch[:valid_frames]
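
Putting the pieces together, a minimal end-to-end sketch (the archive name, dim=64 and the sequential batching are illustrative; see the docstrings above for tuning advice):

import h5py
import numpy as np
import torch

# Write a toy archive in the format read_hdf5_data() expects.
with h5py.File("activations.h5", "w") as hf:
    for i in range(20):
        hf.create_dataset(
            f"dataset_{i}", data=np.random.randn(100, 64).astype(np.float16)
        )

train, valid = read_hdf5_data("activations.h5")  # float16, shuffled rows
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainer = QuantizerTrainer(dim=64, bytes_per_frame=4, device=device)
B = 600  # batch size in frames, matching the tuning advice above
i = 0
while not trainer.done():
    batch = train[i * B : (i + 1) * B].to(device).to(torch.float32)
    i = (i + 1) % (train.shape[0] // B)
    trainer.step(batch)
quantizer = trainer.get_quantizer()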