mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
add codebook loss
This commit is contained in:
parent
cbc9c50bfc
commit
979f574259
@ -0,0 +1,75 @@
|
||||
# A copy from:
|
||||
# https://github.com/danpovey/quantization/blob/master/quantization/checkpoint.py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
from typing import Tuple, Callable
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, function: Callable, *args):
|
||||
# `function` must return either a Tensor or a tuple of Tensors
|
||||
ctx.function = function
|
||||
ctx.args = [x.detach() if isinstance(x, Tensor) else x
|
||||
for x in args]
|
||||
for i in range(len(ctx.args)):
|
||||
if isinstance(args[i], Tensor) and args[i].requires_grad:
|
||||
ctx.args[i].requires_grad = True
|
||||
with torch.no_grad():
|
||||
ans = function(*args)
|
||||
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *ans_grads):
|
||||
if not any([ a is not None for a in ans_grads]):
|
||||
return [None] * len(ctx.args)
|
||||
with torch.enable_grad():
|
||||
ans = ctx.function(*ctx.args)
|
||||
if isinstance(ans, Tensor):
|
||||
assert len(ans_grads) == 1
|
||||
loss = (ans * ans_grads[0]).sum()
|
||||
else:
|
||||
assert len(ans_grads) == len(ans)
|
||||
loss = torch.stack([ (a * g).sum() for a, g in zip(ans, ans_grads)
|
||||
if g is not None ]).sum()
|
||||
|
||||
loss.backward()
|
||||
return tuple([None] + [ a.grad if isinstance(a, Tensor) else None for a in ctx.args ])
|
||||
|
||||
|
||||
|
||||
def checkpoint(function, *args):
    """Evaluate ``function(*args)`` under gradient checkpointing.

    Activations are recomputed during the backward pass instead of being
    stored, saving memory at the cost of running `function` twice.
    """
    return CheckpointFunction.apply(function, *args)
|
||||
|
||||
|
||||
|
||||
|
||||
def _test1():
    """Sanity check: gradients flow back to ``y`` through checkpoint()."""
    x = torch.Tensor([0])
    y = torch.Tensor([1])
    y.requires_grad = True

    def stack_pair(a, b, unused):
        return torch.stack((a, b))

    ans = checkpoint(stack_pair, x, y, None)
    print("ans = ", ans)
    (-ans).sum().backward()
    print("y grad = ", y.grad)
|
||||
|
||||
|
||||
def _test2():
    """Sanity check: two chained checkpoint() calls still propagate grads."""
    x = torch.Tensor([0])
    y = torch.Tensor([1])
    x.requires_grad = True

    def stack_pair(a, b, unused):
        return torch.stack((a, b))

    ans = checkpoint(stack_pair, x, y, None)
    ans = checkpoint(torch.sum, ans)
    print("ans = ", ans)
    (-ans).backward()
    print("x grad = ", x.grad)
|
||||
|
||||
if __name__ == '__main__':
    # Run the quick gradient-flow sanity checks when executed as a script.
    _test1()
    _test2()
|
@ -23,6 +23,8 @@ from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
from vq_prediction import JointCodebookLoss
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
@ -38,6 +40,7 @@ class Transducer(nn.Module):
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
num_codebooks: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -67,6 +70,11 @@ class Transducer(nn.Module):
|
||||
encoder_dim, vocab_size, initial_speed=0.5
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
|
||||
# Two consecutive output frames of hubert are concatenated together.
|
||||
# Detailed in following `forward` function.
|
||||
self.codebook_loss_net = JointCodebookLoss(
|
||||
predictor_channels=encoder_dim, num_codebooks=num_codebooks * 2
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -77,6 +85,7 @@ class Transducer(nn.Module):
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
warmup: float = 1.0,
|
||||
codebook_indices: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -100,6 +109,8 @@ class Transducer(nn.Module):
|
||||
warmup:
|
||||
A value warmup >= 0 that determines which modules are active, values
|
||||
warmup > 1 "are fully warmed up" and all modules will be active.
|
||||
codebook_indices:
|
||||
codebook_indices from a teacher model.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
@ -115,7 +126,45 @@ class Transducer(nn.Module):
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
|
||||
encoder_out, x_lens, middle_layer_output = self.encoder(
|
||||
x, x_lens, warmup=warmup
|
||||
)
|
||||
|
||||
if self.training:
|
||||
# Do distillation.
|
||||
assert codebook_indices is not None
|
||||
assert hasattr(self, "codebook_loss_net")
|
||||
|
||||
# Output rate of hubert is 50 frames per second,
|
||||
# while that of current encoder is 25.
|
||||
# The following code handles two issues:
|
||||
# 1.
|
||||
# Roughly speaking, to generate another frame output,
|
||||
# hubert needs two extra frames,
|
||||
# while current encoder needs extra four frames.
|
||||
# Suppose there are only extra three frames provided,
|
||||
# hubert will generate another frame while current encoder does nothing.
|
||||
# 2.
|
||||
# codebook loss is a frame-wise loss; to let the 25-frame student output
# learn from the 50-frame teacher output, two successive frames of the
# teacher model's output are concatenated together.
|
||||
t_expected = middle_layer_output.shape[1]
|
||||
N, T, C = codebook_indices.shape
|
||||
|
||||
# Handling issue 1.
|
||||
if T >= t_expected * 2:
|
||||
codebook_indices = codebook_indices[:, : t_expected * 2, :]
|
||||
# Handling issue 2.
|
||||
codebook_indices = codebook_indices.reshape(N, t_expected, C * 2)
|
||||
|
||||
assert middle_layer_output.shape[1] == codebook_indices.shape[1]
|
||||
codebook_loss = self.codebook_loss_net(
|
||||
middle_layer_output, codebook_indices
|
||||
)
|
||||
else:
|
||||
# when codebook index is not available.
|
||||
codebook_loss = None
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
@ -190,4 +239,4 @@ class Transducer(nn.Module):
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
return (simple_loss, pruned_loss, codebook_loss)
|
||||
|
@ -0,0 +1,193 @@
|
||||
# A copy from:
|
||||
# https://github.com/danpovey/quantization/blob/master/quantization/prediction.py
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
from typing import Tuple, Optional
|
||||
from checkpoint import checkpoint # from current directory.. could not get relative import to work..
|
||||
|
||||
|
||||
# functional version of joint codebook loss, added so that we can more easily implement
|
||||
# checkpointing to save memory.
|
||||
def joint_codebook_loss(predictor: Tensor,
                        codebook_indexes: Tensor,
                        linear1_weight: Tensor,
                        linear1_bias: Optional[Tensor],
                        codebook_embedding_weight: Tensor,
                        linear2_weight: Tensor,
                        linear2b_weight: Tensor,
                        linear2_bias: Tensor,
                        ignore_index: int,
                        reduction: str) -> Tensor:
    """Functional form of the joint codebook prediction loss.

    Kept as a free function (rather than a module method) so that it can be
    wrapped in `checkpoint()` to save backprop memory.

    Args:
      predictor: predictor tensor of shape (*, predictor_channels)
      codebook_indexes: codebook indexes of shape (*, num_codebooks)
      linear1_weight: weight of shape (hidden_channels, predictor_channels)
      linear1_bias: optional bias of shape (hidden_channels,)
      codebook_embedding_weight: weight of shape ((num_codebooks - 1) * codebook_size,
               hidden_channels)
      linear2_weight: weight of shape (num_codebooks, codebook_size,
               hidden_channels)
      linear2b_weight: weight of shape (num_codebooks, codebook_size,
               predictor_dim)
      linear2_bias: bias of shape (num_codebooks, codebook_size)
      ignore_index: index to ignore in cross entropy loss, e.g. -100
      reduction: reduction in cross entropy loss, e.g. 'sum'

    Returns:
      The cross-entropy loss over all codebooks, reduced per `reduction`.
    """
    num_codebooks = codebook_indexes.shape[-1]
    predictor_channels = predictor.shape[-1]
    hidden_channels = linear1_weight.shape[0]
    # The embedding table stacks (num_codebooks - 1) codebooks row-wise, so
    # the per-codebook size is recovered by division.
    codebook_size = codebook_embedding_weight.shape[0] // (num_codebooks - 1)

    codebook_indexes = codebook_indexes.to(torch.int64)
    assert list(predictor.shape[:-1]) == list(codebook_indexes.shape[:-1])
    # Flatten all leading dims into a single frame axis N.
    predictor = predictor.reshape(-1, predictor.shape[-1]) # (N, predictor_channels)
    codebook_indexes = codebook_indexes.reshape(-1, codebook_indexes.shape[-1])
    first_indexes = codebook_indexes[:,:-1] # all but last codebook indexes; (N, num_codebooks-1)

    # do clamp(min=0) to avoid errors on padding (-100).. these frames will
    # later be ignored in the loss, so the value can be treated as a don't-care.
    # The arange offsets shift codebook k's indexes into codebook k's own
    # row-range of the shared embedding table.
    first_indexes = first_indexes.clamp(min=0) + torch.arange(0, (num_codebooks - 1) * codebook_size,
                                                              step=codebook_size,
                                                              device=first_indexes.device) # (N, num_codebooks-1)

    # Scale keeps the embedding magnitudes small relative to hidden_channels
    # (the table itself is initialized with small weights; see the module).
    first_embeddings_scale = 0.5 * ((hidden_channels / num_codebooks) ** 0.5)
    first_embeddings = torch.nn.functional.embedding(first_indexes,
                                                     codebook_embedding_weight) * first_embeddings_scale # (N, num_codebooks-1, hidden_channels)

    hidden_predictor = torch.nn.functional.linear(predictor, linear1_weight, linear1_bias)
    all_embeddings = torch.cat((hidden_predictor.unsqueeze(1),
                                first_embeddings),
                               dim=1) # (N, num_codebooks, hidden_channels)

    # after cumsum, all positions will contain a contribution from 'hidden_predictor'; and
    # will also contain contributions from all *previous* codebooks. Here, "position" means
    # a position in {0..num_codebooks-1}
    all_embeddings = torch.cumsum(all_embeddings, dim=1) # (N, num_codebooks, hidden_channels)

    all_embeddings = torch.nn.functional.relu(all_embeddings)

    logprobs = torch.matmul(all_embeddings.transpose(0, 1), # (num_codebooks, N, hidden_channels)
                            linear2_weight.transpose(1, 2) # (num_codebooks, hidden_channels, codebook_size)
                            ).transpose(0, 1) # (N, num_codebooks, codebook_size)

    # Direct (linear) connection from the raw predictor to each codebook's
    # logits, in addition to the nonlinear path above.
    logprobs += torch.matmul(predictor, # (N, predictor_channels)
                             linear2b_weight.transpose(1, 2) # (num_codebooks, predictor_channels, codebook_size)
                             ).transpose(0, 1) # (N, num_codebooks, codebook_size)

    logprobs += linear2_bias
    # NOTE(review): cross_entropy normalizes again internally; log_softmax is
    # idempotent, so this extra normalization is harmless, just redundant.
    logprobs = logprobs.log_softmax(dim=2) # (N, num_codebooks, codebook_size)

    return torch.nn.functional.cross_entropy(logprobs.reshape(-1, codebook_size),
                                             codebook_indexes.reshape(-1),
                                             ignore_index=ignore_index,
                                             reduction=reduction)
|
||||
|
||||
|
||||
|
||||
class JointCodebookLoss(nn.Module):
    """
    This module predicts a group of codebook indexes from a vector. The idea is that
    you have a number of codebooks (probably jointly trained), from class Quantizer,
    and you want to predict the probabilities of the codebook entries based on some
    predictor that you are training.

    The simplest thing would be to project the vector using nn.Linear, then
    reshape and use logsoftmax to normalize the probabilities within each group,
    then compute the likelihood. However, this has a constraint that all the
    codebooks are predicted independently of each other. This module allows you
    to predict them jointly, by regressing each codebook on all previous codebooks.
    This is done with a nonlinearity in which the previous codebook entries are combined
    with the input predictor vector, so that the regression is not purely
    linear.

    Args:
      predictor_channels: the number of features that we use to predict the
           codebook indexes, e.g. 2048 (will depend on your model).
      num_codebooks: the number of codebooks that you are predicting;
           will likely be the same as the bytes_per_frame given to the
           QuantizerTrainer that you used to train the Quantizer you
           are predicting.
      hidden_channels: the hidden dimension per codebook (we use a
           1-hidden-layer network with a ReLU); should normally be more than
           codebook_size, but may be less or more than predictor_channels.
      codebook_size: number of entries per codebook (often 256)
      reduction: reduction used for the cross-entropy loss, e.g. 'sum'
      ignore_index: target index ignored by the loss (used for padding
           frames), e.g. -100
      checkpoint: if true, reduce backprop memory at the expense of doing
          the computation twice.
    """
    def __init__(self,
                 predictor_channels: int,
                 num_codebooks: int,
                 hidden_channels: int = 512,
                 codebook_size: int = 256,
                 reduction: str = 'sum',
                 ignore_index: int = -100,
                 checkpoint: bool = True):
        super(JointCodebookLoss, self).__init__()

        assert num_codebooks > 1 # we may later handle this specially.
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.hidden_channels = hidden_channels
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.checkpoint = checkpoint

        self.linear1 = nn.Linear(predictor_channels, hidden_channels)

        # codebook_embedding is used to predict each codebook from previous
        # codebooks, so it's a joint, not independent, model. we'll multiply
        # this by hidden_channels ** 0.5 when we use it; this keeps the magnitude
        # small and allows it to train fast enough (relatively speaking).
        self.codebook_embedding = nn.Embedding((num_codebooks - 1) * codebook_size,
                                               hidden_channels,
                                               _weight=torch.randn((num_codebooks - 1) * codebook_size,
                                                                   hidden_channels) * (hidden_channels ** -0.5))

        # Per-codebook output projections: the nonlinear path (linear2_weight,
        # from the hidden layer) and a direct linear path (linear2b_weight,
        # from the raw predictor), plus a shared bias.
        self.linear2_weight = nn.Parameter(torch.randn(num_codebooks, codebook_size,
                                                       hidden_channels) * (hidden_channels ** -0.5))
        self.linear2b_weight = nn.Parameter(torch.randn(num_codebooks, codebook_size,
                                                        predictor_channels) * (predictor_channels ** -0.5))
        self.linear2_bias = nn.Parameter(torch.zeros(num_codebooks, codebook_size))


    def forward(self,
                predictor: Tensor,
                codebook_indexes: Tensor) -> Tensor:
        """
        Forward function.

        Args:
          predictor: a Tensor of some real type, with shape (*, predictor_channels).
          codebook_indexes: a Tensor of integers, of shape (*, num_codebooks),
            where the '*' should be the same as for `predictor`. It will be
            converted to type torch.int64. Should contain indexes of codebook
            entries, in {0..codebook_size-1},
            or negative values which will be interpreted as "no codebook index here"
            (e.g. due to padding); we assume that each frame will either have
            all-negative or all-nonnegative indexes, meaning that (codebook_indexes >= 0)
            should not vary as you change the last index into it.

        Returns:
          cross_entropy_loss, will be a total negated log-probability, assuming
          reduction == 'sum'.
        """
        # All parameters are passed explicitly so the computation can be run
        # as a free function under checkpoint() (saving backprop memory).
        args = (predictor, codebook_indexes,
                self.linear1.weight, self.linear1.bias,
                self.codebook_embedding.weight,
                self.linear2_weight,
                self.linear2b_weight,
                self.linear2_bias,
                self.ignore_index,
                self.reduction)
        if self.checkpoint:
            return checkpoint(joint_codebook_loss, *args)
        else:
            return joint_codebook_loss(*args)
|
Loading…
x
Reference in New Issue
Block a user