Refactor so there is no bottleneck, only prediction

Daniel Povey 2021-09-19 15:38:34 +08:00
parent 0f29f35a42
commit b0dd4215fe


@@ -71,7 +71,7 @@ class ConformerTrunk(nn.Module):
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
@@ -149,8 +149,7 @@ class BidirectionalConformer(nn.Module):
num_encoder_layers: Number of encoder layers in the "trunk" that
encodes the acoustic features
num_ctc_encoder_layers: Number of layers in the CTC encoder
that comes after the trunk (and possibly the discrete
bottleneck, if bypass_bottleneck == True.
that comes after the trunk.
These are just conformer encoder layers.
num_decoder_layers: Number of layers in the attention decoder;
this goes from the trunk to the word-pieces or phones.
@@ -168,17 +167,14 @@ class BidirectionalConformer(nn.Module):
useful function is to prevent "trivial" solutions
such as collapse of the distribution to a single symbol,
or symbols that are highly correlated across time.
bypass_bottleneck: If true, bypass the discrete bottleneck
when predicting the CTC output and the decoder
that decodes the word-pieces or phones.
dropout: Dropout probability
cnn_module_kernel: Kernel size in forward conformer layers
is_bpe: If false, we'll add one (for EOS) to the number of
classes at the output of the decoder
use_feat_batchnorm: If true, apply batchnorm to the input features.
discrete_bottleneck_tot_classes: Total number of classes
discretization_tot_classes: Total number of classes
(across all groups) in the discrete bottleneck
discrete_bottleneck_num_groups: Number of groups of classes/symbols
discretization_num_groups: Number of groups of classes/symbols
in the discrete bottleneck
"""
def __init__(
@@ -195,18 +191,15 @@ class BidirectionalConformer(nn.Module):
num_reverse_encoder_layers: int = 4,
num_reverse_decoder_layers: int = 4,
num_self_predictor_layers: int = 3,
bypass_bottleneck: bool = True,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
is_bpe: bool = False,
use_feat_batchnorm: bool = True,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 4
discretization_tot_classes: int = 512,
discretization_num_groups: int = 4
) -> None:
super(BidirectionalConformer, self).__init__()
self.bypass_bottleneck = bypass_bottleneck
self.trunk = ConformerTrunk(num_features, subsampling_factor,
d_model, nhead, dim_feedforward,
num_trunk_encoder_layers, dropout,
@@ -324,53 +317,64 @@ class BidirectionalConformer(nn.Module):
for _ in range(num_self_predictor_layers)])
self.discrete_bottleneck = DiscreteBottleneck(
self.sample_and_predict = SampleAndPredict(
dim=d_model,
tot_classes=discrete_bottleneck_tot_classes,
num_groups=discrete_bottleneck_num_groups)
tot_classes=discretization_tot_classes,
num_groups=discretization_num_groups)
def forward(self, x: Tensor, supervision: Optional[Supervisions] = None,
need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
need_softmax: bool = True) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
"""
Forward function that "encodes" the features.
Forward function that "encodes" the acoustic features through the "trunk"
(the shared part of the encoding of the acoustic features)
Args:
x:
The input tensor. Its shape is [N, T, F], i.e. [batch_size, num_frames, num_features].
supervision:
Supervision in lhotse format (optional)
Supervision in lhotse format (optional; needed only for acoustic length
information)
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling). Used only to compute masking information.
need_softmax:
If true, the last output ("softmax") will be computed. This can be useful
in the reverse model, but only necessary if straight_through_scale != 1.0.
Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where:
Returns: (memory, pos_emb, key_padding_mask), where:
memory: a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
is actually a subsampled form of the num_frames of the input `x`.
If self.bypass_bottleneck, it will be taken before the discrete
bottleneck; otherwise, from after.
bn_memory: The same shape as `memory`, but comes after the discrete bottleneck
regardless of the value of self.bypass_bottleneck.
pos_emb: The relative positional embedding; will be given to ctc_encoder_forward()
sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse' model.
softmax: a "soft" version of `sampled`. Will only be returned if need_softmax == True;
else will be None.
key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of
shape [N, T] (only if supervision was supplied, else None).
"""
encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
memory, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
return memory, pos_emb, memory_key_padding_mask
bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output)
memory = encoder_output if self.bypass_bottleneck else bn_memory
def sample_forward(self, memory: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
"""
Given the "memory" from forward(), run the sample_and_redict module.
See documentation for forward() of class SampleAndPredict for more info.
return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask)
Returns (sampled, softmax, positive_embed_shifted, negative_embed_shifted),
where positive_embed_shifted, for instance, is positive_embed
shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
(T, N, E) = positive_embed.shape
positive_embed_shifted = torch.cat((torch.zeros(1, N, E), positive_embed[:-1,:,:]), dim=0)
"""
(sampled, softmax, positive_embed, negative_embed) = self.sample_and_predict(memory)
(T, N, E) = memory.shape
device = memory.device
zeros = torch.zeros(1, N, E, device=device)
negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
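A brief usage sketch (illustrative, not part of this commit) of how the refactored pieces fit together, mirroring _test_bidirectional_conformer() near the bottom of the file; `model`, `feats`, `supervision` and `tokens` are hypothetical:
memory, pos_emb, key_padding_mask = model(feats, supervision)
ctc_output = model.ctc_encoder_forward(memory, pos_emb, key_padding_mask)
(sampled, softmax, positive_embed_shifted,
 negative_embed_shifted) = model.sample_forward(memory)
reverse_logprob = model.reverse_decoder_forward(
    positive_embed_shifted, key_padding_mask, sampled, softmax, tokens,
    sos_id=1, eos_id=2, padding_id=0)
self_pred_logprob = model.self_prediction_forward(
    negative_embed_shifted, key_padding_mask, sampled, softmax)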
def decoder_forward(
self,
@@ -479,11 +483,10 @@ class BidirectionalConformer(nn.Module):
def self_prediction_forward(
self,
bn_memory_shifted: torch.Tensor,
negative_embed_shifted: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sampled: torch.Tensor,
softmax: Optional[torch.Tensor],
reverse_grad: bool = True) -> Tensor:
softmax: Optional[torch.Tensor]) -> Tensor:
"""
Returns the total log-prob of the labels sampled in the discrete
bottleneck layer, as predicted using a relatively simple model that
@@ -491,45 +494,29 @@ class BidirectionalConformer(nn.Module):
[Appears on the denominator of an expression for mutual information].
Args:
bn_memory_shifted:
It's the bn_memory output of forward(), with shape [T, N, E], shifted
by one so that bn_shifted_memory[t] == bn_memory[t-1], as in:
(T, N, E) = bn_memory.shape
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
negative_embed_shifted:
The negative_embed_shifted output of self.sample_forward(), with shape [T, N, E]
memory_key_padding_mask:
The padding mask from the encoder, of shape [N, T], boolean, True
for masked locations.
sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C]
where C corresponds to `discrete_bottleneck_tot_classes`
where C corresponds to `discretization_tot_classes`
as given to the constructor. This will be needed for the 'reverse'
model.
softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
reverse_grad: will likely be true. If true, the gradient is reversed twice
in this computation, so that we train predictors with the correct
gradient, i.e. to predict, not anti-predict (since the return value
of this function will appear with positive, not negative, sign in the
loss function, so will be minimized).
The gradient w.r.t. the non-self inputs to this function, though (i.e.
bn_memory_shifted, sampled, softmax) will not be reversed, though.
Returns:
A scalar tensor, the **sum** of label smoothing loss over utterances
A scalar tensor, the **sum** of the log-prob loss over utterances
in the batch without any normalization.
"""
if reverse_grad:
# Reversing gradient for bn_memory_shifted puts the gradient back into
# the correct sign; we reversed it in
# self.discrete_bottleneck.compute_prob(), in order that
# self.self.predictor_encoder can learn to predict.
# (You have to read the code in reverse, to reason about
# what happens to the gradients).
bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)
# no mask is needed for self_predictor_encoder; its CNN
# layer uses left-padding only, making it causal.
predictor = self.self_predictor_encoder(bn_memory_shifted)
# layer uses left-padding only, making it causal, so the mask
# is redundant (it wouldn't affect any of the
# outputs we care about).
predictor = self.self_predictor_encoder(negative_embed_shifted)
prob = self.discrete_bottleneck.compute_prob(predictor,
prob = self.sample_and_predict.compute_prob(predictor,
sampled, softmax,
memory_key_padding_mask,
reverse_grad=True)
@@ -538,7 +525,7 @@ class BidirectionalConformer(nn.Module):
def reverse_decoder_forward(
self,
bn_memory_shifted: torch.Tensor,
positive_embed_shifted: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sampled: torch.Tensor,
softmax: Optional[torch.Tensor],
@@ -553,15 +540,14 @@ class BidirectionalConformer(nn.Module):
supervision word-sequence.
Args:
bn_memory_shifted:
It's the bn_memory output of forward(), with shape [T, N, E], shifted
by one so that shifted_memory[t] == bn_memory[t-1], as in:
(T, N, E) = memory.shape
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
positive_embed_shifted:
It's the positive_embed_shifted output of self.sample_forward(), with
shape [T, N, E]
memory_key_padding_mask:
The padding mask from the encoder.
sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse' model.
sampled: is a Tensor of shape [T, N, C] where C corresponds to
`discretization_tot_classes` as given to the constructor.
This will be needed for the 'reverse' model.
softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
@@ -583,7 +569,7 @@ class BidirectionalConformer(nn.Module):
token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]
tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
padding_value=padding_id).to(bn_memory_shifted.device)
padding_value=padding_id).to(positive_embed_shifted.device)
print("tokens_padded = ", tokens_padded)
tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
@@ -601,17 +587,17 @@ class BidirectionalConformer(nn.Module):
src_key_padding_mask=tokens_key_padding_mask)
# token_memory is of shape (S, N, C), if S is length of token sequence.
T = bn_memory_shifted.shape[0]
T = positive_embed_shifted.shape[0]
# the targets, here, are the hidden discrete symbols we are predicting
tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
tgt_mask = generate_square_subsequent_mask(T, device=positive_embed_shifted.device)
hidden_predictor = self.reverse_decoder(
tgt=bn_memory_shifted,
tgt=positive_embed_shifted,
memory=token_memory,
tgt_mask=tgt_mask,
memory_key_padding_mask=tokens_key_padding_mask)
total_prob = self.discrete_bottleneck.compute_prob(
total_prob = self.sample_and_predict.compute_prob(
hidden_predictor,
sampled,
softmax,
@@ -676,6 +662,9 @@ class ReverseGrad(torch.autograd.Function):
def backward(ctx, x_grad):
return -x_grad
def reverse_gradient(x: Tensor) -> Tensor:
return ReverseGrad.apply(x)
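A small self-contained demonstration (illustrative only, assuming the usual identity forward pass for ReverseGrad) of what reverse_gradient() does, and why the double reversal performed by negative_embed plus compute_prob(reverse_grad=True) brings the sign back to normal:
import torch
x = torch.tensor([1.0, 2.0], requires_grad=True)
reverse_gradient(x).sum().backward()
print(x.grad)    # tensor([-1., -1.]): the gradient sign is flipped
x.grad = None
reverse_gradient(reverse_gradient(x)).sum().backward()
print(x.grad)    # tensor([1., 1.]): two reversals cancel out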
class DebugGrad(torch.autograd.Function):
@staticmethod
@@ -693,19 +682,21 @@ class DebugGrad(torch.autograd.Function):
class DiscreteBottleneck(nn.Module):
"""
This layer forces its input through an information bottleneck via
a discretization operation with sampling, and allows you to
predict the likelihood of those discretized values.
We use the torch-flow-sampling
package for this, to provide a differentiable softmax that should be
much better than Gumbel in terms of actually giving an information
bottleneck.
class SampleAndPredict(nn.Module):
"""This module discretizes its input and lets you predict the
discrete classes.
We use the torch-flow-sampling package for this, to provide a differentiable
softmax that should be much better than Gumbel in terms of actually giving
an information bottleneck.
(However, if straight_through_scale == 1.0, which actually seems
to be working fine so far, it's the same as just sampling from the
categorical distribution and using straight-through derivatives
(i.e. the derivatives are as if that output had just been a softmax).
This may depend somewhat on the model; straight_through_scale == 0.0
is definitely safer from a correctness point of view.)
Args:
dim: The input and output dimension of the discrete bottleneck
operation.
dim: The input feature dimension
tot_classes: The total number of classes (across all groups
of classes); each group is separately discretized
num_groups: The number of groups of classes; discretization
@@ -728,10 +719,10 @@ class DiscreteBottleneck(nn.Module):
1.0 - straight_through_scale.
min_prob_ratio: For any class whose average softmax
output, for a given minibatch, is less than
min_prob_ratio times
include_predictor: If true, include the parameters
necessary to predict the likelihoods of the
classes from some kind of input embedding.
min_prob_ratio times the average probability,
boost its probability; this is a mechanism
to avoid "losing" classes, we are hoping it won't really
be necessary in practice.
"""
def __init__(
self,
@@ -739,12 +730,11 @@ class DiscreteBottleneck(nn.Module):
tot_classes: int,
num_groups: int,
interp_prob: float = 1.0,
straight_through_scale: float = 0.333,
straight_through_scale: float = 0.0,
min_prob_ratio: float = 0.1,
include_predictor: bool = True
):
super(DiscreteBottleneck, self).__init__()
self.norm_in = nn.LayerNorm(dim)
super(SampleAndPredict, self).__init__()
self.linear1 = nn.Linear(dim, tot_classes)
self.num_groups = num_groups
@@ -753,25 +743,24 @@ class DiscreteBottleneck(nn.Module):
self.min_prob_ratio = min_prob_ratio
self.tot_classes = tot_classes
self.classes_per_group = tot_classes // num_groups
# prob_boost relates to the min_prob_ratio setting. It's not configurable for now.
self.prob_boost = 1.0e-05
# class_probs is a rolling mean of the output of the sampling operation.
# When any element of it gets below self.min_prob_ratio / self.classes_per_group,
# we boost the class's probability by adding self.prob_boost to
# that element of self.class_offsets
self.class_probs_decay = 0.9
self.class_probs_decay = 0.95
self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group)
# class_offsets is a bias term that we add to logits before the sampling
# operation in order to enforce that no class is too infrequent
# (c.f. 'min_prob_ratio').
self.register_buffer('class_offsets', torch.zeros(tot_classes))
self.linear2 = nn.Linear(tot_classes, dim, bias=False)
self.norm_out = nn.LayerNorm(dim)
if include_predictor:
# pred_linear predicts the class probabilities from a predictor
# embedding.
# embedding of dimension 'dim' supplied by the user.
self.pred_linear = nn.Linear(dim, tot_classes)
if self.num_groups > 1:
@@ -790,7 +779,13 @@ class DiscreteBottleneck(nn.Module):
# [ True, True, True, True],
# [ True, True, True, True]])
self.register_buffer('pred_cross_mask',
((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
((torch.arange(d) // c).unsqueeze(1) >=
(torch.arange(d) // c).unsqueeze(0)))
# linear2 and post_layer_norm come after the sampling.
self.linear2 = nn.Linear(tot_classes, dim)
self.post_layer_norm = nn.LayerNorm(dim)
self._reset_parameters()
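A worked example (illustrative, not part of this commit) of the pred_cross_mask built above, for a tiny hypothetical configuration of d = 4 total classes in groups of c = 2; the mask is block lower-triangular by group, matching the all-True rows shown in the comment above:
import torch
d, c = 4, 2
group = torch.arange(d) // c                   # tensor([0, 0, 1, 1])
mask = group.unsqueeze(1) >= group.unsqueeze(0)
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])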
def _reset_parameters(self):
@@ -798,17 +793,14 @@ class DiscreteBottleneck(nn.Module):
torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
"""
Forward computation.
Forward computation. See also compute_prob().
Args:
x: The input tensor, of shape (S, N, E) where S is the sequence length,
N is the batch size and E is the embedding dim.
Returns (embeddding, sampled, softmax), where:
Returns (sampled, softmax, positive_embed, negative_embed), where:
embedding: of shape (S, N, E) where E is the embedding dimension (`dim` arg
to the constructor), this is the output embedding; it is projected from the
sampled class probabilities.
sampled: of shape (S, N, C) where C is the `tot_classes` to the
constructor, these are the sampled one-hot vectors or interpolations
thereof. They will be needed if we try to predict the discrete values
@@ -820,19 +812,31 @@ class DiscreteBottleneck(nn.Module):
expectation of the result of sampling -> lower-variance derivatives.
This is unnecessary if straight_through_scale == 1.0, since in that
case it would not affect the backpropagated derivatives.
positive_embed: The samples projected back down to the embedding
dimension (`dim` as passed to the constructor), and layer-normed.
negative_embed: This is numerically the same value as positive_embed,
but has its gradients reversed prior to the projection and
LayerNorm. This is intended to be used for terms that appear
with opposite sign in the loss function, to be fed to
something whose gradient is already (going to be) reversed:
specifically, the self-prediction network.
"""
x = self.norm_in(x) * 5 # * 5 gives lower entropy.. starts training faster..
x = self.linear1(x)
x = self.linear1(x * 5) # multiplying by 5 gives lower entropy, makes it
# begin training faster..
if self.min_prob_ratio > 0.0:
x = x + self.class_offsets
(S, N, tot_classes) = x.shape
x = x.reshape(S, N, self.num_groups, self.classes_per_group)
# This is a little wasteful since we already compute the softmax
# inside 'flow_sample'.
# This is a little wasteful since we already compute the softmax inside
# 'flow_sample'.
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
if random.random() < 0.001:
# Some info that's useful for debug.
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
logsoftmax_temp = (softmax_temp + 1.0e-20).log()
negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
@@ -841,8 +845,9 @@ class DiscreteBottleneck(nn.Module):
global_log_softmax = (global_softmax + 1.0e-20).log()
global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item())
print("SampleAndPredict: entropy = ",
-negentropy.to('cpu').item(), ", averaged entropy = ",
-global_negentropy.to('cpu').item())
x = torch_flow_sampling.flow_sample(x,
@@ -854,19 +859,22 @@ class DiscreteBottleneck(nn.Module):
sampled = x
if self.training and False:
if self.training and self.min_prob_ratio > 0.0:
mean_class_probs = torch.mean(x.detach(), dim=(0,1))
self.class_probs = (self.class_probs * self.class_probs_decay +
mean_class_probs * (1.0 - self.class_probs_decay))
prob_floor = self.min_prob_ratio / self.classes_per_group
self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
embedding = self.norm_out(self.linear2(x))
#if random.random() < 0.01:
# return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"))
#else:
return (embedding, sampled, softmax)
positive_embed = self.post_layer_norm(self.linear2(sampled))
negative_embed = self.post_layer_norm(self.linear2(reverse_gradient(sampled)))
if random.random() < 0.002:
return (DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"),
positive_embed, negative_embed)
else:
return (sampled, softmax, positive_embed, negative_embed)
def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
@@ -891,23 +899,23 @@ class DiscreteBottleneck(nn.Module):
padding_mask: Optionally, a boolean tensor of shape (N, S), i.e.
(batch_size, sequence_length), with True in masked positions
that are to be ignored in the sum of probabilities.
reverse_grad: If true, negate the gradient that is passed back
to 'x' and to the modules self.pred_linear and pred_cross.
This will be useful in computing a loss function that has
a likelihood term with negative sign (i.e. the self-prediction).
We'll later need negate the gradient one more more time
where we give the input to the prediction module that
generated 'x'.
We'll later need to negate this gradient one more time
(it's expected, when reverse_grad == True, that x would
derive somehow from `negative_embed`, so that the gradient
will eventually go back to the correct sign.)
Returns a scalar Tensor representing the total probability.
"""
if reverse_grad:
sampled = ReverseGrad.apply(sampled)
sampled = reverse_gradient(sampled)
if softmax is None:
softmax = sampled
elif reverse_grad:
softmax = ReverseGrad.apply(softmax)
softmax = reverse_gradient(softmax)
logprobs = self.pred_linear(x)
@@ -942,7 +950,7 @@ class DiscreteBottleneck(nn.Module):
tot_prob = (logprobs * softmax).sum()
if reverse_grad:
tot_prob = ReverseGrad.apply(tot_prob)
tot_prob = reverse_gradient(tot_prob)
return tot_prob
@@ -1814,28 +1822,27 @@ def _test_bidirectional_conformer():
print("tokens = ", tokens)
print("supervision = ", supervision)
# memory: [T, N, C]
(memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision)
(memory, pos_emb, key_padding_mask) = m(feats, supervision)
# ctc_output: [N, T, C].
ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)
(sampled, softmax, positive_embed_shifted, negative_embed_shifted) = m.sample_forward(memory)
decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens,
sos_id=1,
eos_id=2)
print("decoder logprob = ", decoder_logprob)
(T, N, E) = bn_memory.shape
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
reverse_decoder_logprob = m.reverse_decoder_forward(
bn_memory_shifted, key_padding_mask,
positive_embed_shifted, key_padding_mask,
sampled, softmax, tokens,
sos_id=1, eos_id=2, padding_id=0)
print("reverse decoder logprob = ", reverse_decoder_logprob)
self_prediction_logprob = m.self_prediction_forward(
bn_memory_shifted, key_padding_mask,
negative_embed_shifted, key_padding_mask,
sampled, softmax)
print("self prediction logprob = ", self_prediction_logprob)
@@ -1850,24 +1857,24 @@ def _test_discrete_bottleneck():
tot_classes = 256
num_groups = 8
interp_prob = 1.0
straight_through_scale = 0.0 # will change
straight_through_scale = 1.0 # will change
need_softmax = True
b = DiscreteBottleneck(dim, tot_classes, num_groups,
b = SampleAndPredict(dim, tot_classes, num_groups,
interp_prob, straight_through_scale).to(device)
self_predictor = nn.Linear(tot_classes, dim).to(device)
from_feats_predictor = nn.Linear(dim, dim).to(device)
model = nn.ModuleList([b, self_predictor, from_feats_predictor])
from_negative_embed_predictor = nn.Linear(dim, dim).to(device)
model = nn.ModuleList([b, from_feats_predictor, from_negative_embed_predictor])
model.train()
optim = torch.optim.Adam(params=model.parameters(),
lr=3.0e-04)
scale = 0.3 # determines the feature correlation..should be between 0 and 1.
scale = 0.5 # determines the feature correlation..should be between 0 and 1.
#https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
# scale corresponds to rho^2, rho being sqrt(scale).
mutual_information = dim * -0.5 * math.log(1.0 - scale)
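A quick worked number (illustrative) for the formula above: with scale = rho^2 = 0.5 the mutual information is -0.5 * log(1 - 0.5) ≈ 0.347 nats per feature dimension; `dim` is not shown in this hunk, so 256 below is only a hypothetical example.
import math
scale = 0.5
per_dim = -0.5 * math.log(1.0 - scale)   # ≈ 0.3466 nats per dimension
print(256 * per_dim)                     # ≈ 88.7 nats for a hypothetical dim = 256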
@@ -1886,37 +1893,34 @@ def _test_discrete_bottleneck():
#print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}")
bn_memory, sampled, softmax = b(feats)
sampled, softmax, positive_embed, negative_embed = b(feats)
E = dim
negative_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device),
negative_embed[:-1,:,:]), dim=0)
positive_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device),
positive_embed[:-1,:,:]), dim=0)
# using feats2 instead of feats will limit the mutual information,
# to the MI between feats and feats2, which we computed and printed
# above as mutual_information.
predictor = from_feats_predictor(feats2)
prob = b.compute_prob(predictor, sampled, softmax)
if True:
sampled_reversed = ReverseGrad.apply(sampled)
predictor_reversed = self_predictor(sampled_reversed)
predictor_shifted = from_negative_embed_predictor(negative_embed_shifted)
if True:
predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
predictor_reversed[:-1,:,:]), dim=0)
else:
# skip shifting.. want to see the effect..
predictor_reversed_shifted = predictor_reversed
self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
self_prob = b.compute_prob(predictor_shifted, sampled, softmax,
reverse_grad=True)
normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
normalized_self_prob = (self_prob / (T * N))
normalized_prob = (prob / (T * N)).to('cpu').item()
normalized_prob = (prob / (T * N))
loss = -(prob - self_prob)
normalized_loss = loss / (T * N)
normalized_loss = -(normalized_prob - normalized_self_prob)
if i % 200 == 0:
print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob.to('cpu').item()} - {-normalized_self_prob.to('cpu').item()} = {normalized_loss.to('cpu').item()}")
normalized_loss.backward()