Refactor so there is no bottleneck, only prediction

Daniel Povey 2021-09-19 15:38:34 +08:00
parent 0f29f35a42
commit b0dd4215fe


@@ -71,7 +71,7 @@ class ConformerTrunk(nn.Module):
     def forward(
         self, x: torch.Tensor, supervision: Optional[Supervisions] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
           x:
@@ -149,8 +149,7 @@ class BidirectionalConformer(nn.Module):
           num_encoder_layers: Number of encoder layers in the "trunk" that
                    encodes the acoustic features
           num_ctc_encoder_layers: Number of layers in the CTC encoder
-                   that comes after the trunk (and possibly the discrete
-                   bottleneck, if bypass_bottleneck == True.
+                   that comes after the trunk.
                    These are just conformer encoder layers.
           num_decoder_layers: Number of layers in the attention decoder;
                    this goes from the trunk to the word-pieces or phones.
@@ -168,17 +167,14 @@ class BidirectionalConformer(nn.Module):
                    useful function is to prevent "trivial" solutions
                    such as collapse of the distribution to a single symbol,
                    or symbols that are highly correlated across time.
-          bypass_bottleneck: If true, bypass the discrete bottleneck
-                   when predicting the CTC output and the decoder
-                   that decodes the word-pieces or phones.
           dropout: Dropout probability
           cnn_module_kernel: Kernel size in forward conformer layers
           is_bpe: If false, we'll add one (for EOS) to the number of
                    classes at the output of the decoder
           use_feat_batchnorm: If true, apply batchnorm to the input features.
-          discrete_bottleneck_tot_classes: Total number of classes
+          discretization_tot_classes: Total number of classes
                    (across all groups) in the discrete bottleneck
-          discrete_bottleneck_num_groups: Number of groups of classes/symbols
+          discretization_num_groups: Number of groups of classes/symbols
                    in the discrete bottleneck
         """
     def __init__(
@@ -195,18 +191,15 @@ class BidirectionalConformer(nn.Module):
             num_reverse_encoder_layers: int = 4,
             num_reverse_decoder_layers: int = 4,
             num_self_predictor_layers: int = 3,
-            bypass_bottleneck: bool = True,
             dropout: float = 0.1,
             cnn_module_kernel: int = 31,
             is_bpe: bool = False,
             use_feat_batchnorm: bool = True,
-            discrete_bottleneck_tot_classes: int = 512,
-            discrete_bottleneck_num_groups: int = 4
+            discretization_tot_classes: int = 512,
+            discretization_num_groups: int = 4
     ) -> None:
         super(BidirectionalConformer, self).__init__()

-        self.bypass_bottleneck = bypass_bottleneck
         self.trunk = ConformerTrunk(num_features, subsampling_factor,
                                     d_model, nhead, dim_feedforward,
                                     num_trunk_encoder_layers, dropout,
@@ -324,53 +317,64 @@ class BidirectionalConformer(nn.Module):
                                               for _ in range(num_self_predictor_layers)])

-        self.discrete_bottleneck = DiscreteBottleneck(
+        self.sample_and_predict = SampleAndPredict(
             dim=d_model,
-            tot_classes=discrete_bottleneck_tot_classes,
-            num_groups=discrete_bottleneck_num_groups)
+            tot_classes=discretization_tot_classes,
+            num_groups=discretization_num_groups)

     def forward(self, x: Tensor, supervision: Optional[Supervisions] = None,
-                need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
+                need_softmax: bool = True) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         """
-        Forward function that "encodes" the features.
+        Forward function that "encodes" the acoustic features through the "trunk"
+        (the shared part of the encoding of the acoustic features).

        Args:
          x:
            The input tensor. Its shape is [N, T, F], i.e. [batch_size, num_frames, num_features].
          supervision:
-           Supervision in lhotse format (optional)
+           Supervision in lhotse format (optional; needed only for acoustic length
+           information)
           See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
            (CAUTION: It contains length information, i.e., start and number of
            frames, before subsampling).  Used only to compute masking information.
-         need_softmax:
-           If true, the last output ("softmax") will be computed.  This can be useful
-           in the reverse model, but only necessary if straight_through_scale != 1.0.

-       Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where:
+       Returns: (memory, pos_emb, key_padding_mask), where:
          memory: a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
             is actually a subsampled form of the num_frames of the input `x`.
             If self.bypass_bottleneck, it will be taken before the discrete
             bottleneck; otherwise, from after.
-         bn_memory: The same shape as `memory`, but comes after the discrete bottleneck
-            regardless of the value of self.bypass_bottleneck.
          pos_emb: The relative positional embedding; will be given to ctc_encoder_forward()
-         sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
-            as given to the constructor.  This will be needed for the 'reverse' model.
-         softmax: a "soft" version of `sampled`.  Will only be returned if need_softmax == True;
-            else will be None.
          key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of
             shape [N, T] (only if supervision was supplied, else None).
        """
-        encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
-
-        bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output)
-
-        memory = encoder_output if self.bypass_bottleneck else bn_memory
-
-        return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask)
+        memory, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
+        return memory, pos_emb, memory_key_padding_mask
+
+    def sample_forward(self, memory: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
+        """
+        Given the "memory" from forward(), run the sample_and_predict module.
+        See documentation for forward() of class SampleAndPredict for more info.
+
+        Returns (sampled, softmax, positive_embed_shifted, negative_embed_shifted),
+        where positive_embed_shifted, for instance, is positive_embed
+        shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
+            (T, N, E) = positive_embed.shape
+            positive_embed_shifted = torch.cat((torch.zeros(1, N, E), positive_embed[:-1,:,:]), dim=0)
+        """
+        (sampled, softmax, positive_embed, negative_embed) = self.sample_and_predict(memory)
+        (T, N, E) = memory.shape
+        device = memory.device
+        zeros = torch.zeros(1, N, E).to(memory.device)
+        negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
+        positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
+        return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
     def decoder_forward(
             self,
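Aside on the shift-by-one used in sample_forward(): position t of the shifted tensor only carries information from positions strictly before t, so a predictor fed the shifted tensor is causal with respect to the symbols it has to predict. A minimal check of that semantics (illustrative only, not code from this commit):

    import torch

    T, N, E = 5, 2, 3                        # arbitrary small sizes
    x = torch.randn(T, N, E)
    x_shifted = torch.cat((torch.zeros(1, N, E), x[:-1, :, :]), dim=0)

    assert torch.equal(x_shifted[0], torch.zeros(N, E))   # first frame is all zeros
    assert torch.equal(x_shifted[1:], x[:-1])             # x_shifted[t] == x[t-1] for t >= 1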
@@ -479,11 +483,10 @@ class BidirectionalConformer(nn.Module):
     def self_prediction_forward(
             self,
-            bn_memory_shifted: torch.Tensor,
+            negative_embed_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
-            softmax: Optional[torch.Tensor],
-            reverse_grad: bool = True) -> Tensor:
+            softmax: Optional[torch.Tensor]) -> Tensor:
         """
         Returns the total log-prob of the labels sampled in the discrete
         bottleneck layer, as predicted using a relatively simple model that
@@ -491,45 +494,29 @@ class BidirectionalConformer(nn.Module):
        [Appears on the denominator of an expression for mutual information].

        Args:
-          bn_memory_shifted:
-            It's the bn_memory output of forward(), with shape [T, N, E], shifted
-            by one so that bn_shifted_memory[t] == bn_memory[t-1], as in:
-               (T, N, E) = bn_memory.shape
-               bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
+          negative_embed_shifted:
+            The negative_embed_shifted output of self.sample_forward(), with shape [T, N, E]
          memory_key_padding_mask:
            The padding mask from the encoder, of shape [N, T], boolean, True
            for masked locations.
          sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C]
-            where C corresponds to `discrete_bottleneck_tot_classes`
+            where C corresponds to `discretization_tot_classes`
             as given to the constructor.  This will be needed for the 'reverse'
             model.
          softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
-          reverse_grad: will likely be true.  If true, the gradient is reversed twice
-            in this computation, so that we train predictors with the correct
-            gradient, i.e. to predict, not anti-predict (since the return value
-            of this function will appear with positive, not negative, sign in the
-            loss function, so will be minimized).
-            The gradient w.r.t. the non-self inputs to this function, though (i.e.
-            bn_memory_shifted, sampled, softmax) will not be reversed, though.

        Returns:
-            A scalar tensor, the **sum** of label smoothing loss over utterances
+            A scalar tensor, the **sum** of the log-prob loss over utterances
             in the batch without any normalization.
        """
-        if reverse_grad:
-            # Reversing gradient for bn_memory_shifted puts the gradient back into
-            # the correct sign; we reversed it in
-            # self.discrete_bottleneck.compute_prob(), in order that
-            # self.self.predictor_encoder can learn to predict.
-            # (You have to read the code in reverse, to reason about
-            # what happens to the gradients).
-            bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)
-
         # no mask is needed for self_predictor_encoder; its CNN
-        # layer uses left-padding only, making it causal.
-        predictor = self.self_predictor_encoder(bn_memory_shifted)
+        # layer uses left-padding only, making it causal, so the mask
+        # is redundant (it wouldn't affect any of the
+        # outputs we care about).
+        predictor = self.self_predictor_encoder(negative_embed_shifted)

-        prob = self.discrete_bottleneck.compute_prob(predictor,
+        prob = self.sample_and_predict.compute_prob(predictor,
                                                      sampled, softmax,
                                                      memory_key_padding_mask,
                                                      reverse_grad=True)
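The comment above appeals to the self-predictor's convolution being causal because it pads only on the left. A self-contained sketch of that construction (illustrative; the module and kernel size here are made up, not the repo's actual self_predictor_encoder):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CausalConv1d(nn.Module):
        """Conv1d padded only on the left, so output[t] depends on input[0..t]."""
        def __init__(self, channels: int, kernel_size: int):
            super().__init__()
            self.left_pad = kernel_size - 1
            self.conv = nn.Conv1d(channels, channels, kernel_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x has shape (N, C, T); pad (kernel_size - 1) zeros at the start of time.
            return self.conv(F.pad(x, (self.left_pad, 0)))

    m = CausalConv1d(channels=4, kernel_size=3)
    x = torch.randn(1, 4, 10)
    y1 = m(x)
    x2 = x.clone()
    x2[:, :, 7] += 1.0                                   # perturb frame 7
    y2 = m(x2)
    assert torch.allclose(y1[:, :, :7], y2[:, :, :7])    # earlier frames unaffected

Because padded (future) frames can never influence earlier outputs, masking them out would not change any of the outputs that matter, which is why the key_padding_mask is redundant here.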
@@ -538,7 +525,7 @@ class BidirectionalConformer(nn.Module):
     def reverse_decoder_forward(
             self,
-            bn_memory_shifted: torch.Tensor,
+            positive_embed_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
@@ -553,15 +540,14 @@ class BidirectionalConformer(nn.Module):
        supervision word-sequence.

        Args:
-          bn_memory_shifted:
-            It's the bn_memory output of forward(), with shape [T, N, E], shifted
-            by one so that shifted_memory[t] == bn_memory[t-1], as in:
-               (T, N, E) = memory.shape
-               bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
+          positive_embed_shifted:
+            It's the positive_embed_shifted output of self.sample_forward(), with
+            shape [T, N, E]
          memory_key_padding_mask:
            The padding mask from the encoder.
-          sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
-            as given to the constructor.  This will be needed for the 'reverse' model.
+          sampled: is a Tensor of shape [T, N, C] where C corresponds to
+            `discretization_tot_classes` as given to the constructor.
+            This will be needed for the 'reverse' model.
          softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
          token_ids:
            A list-of-list IDs.  Each sublist contains IDs for an utterance.
@@ -583,7 +569,7 @@ class BidirectionalConformer(nn.Module):
         token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]

         tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
-                                     padding_value=padding_id).to(bn_memory_shifted.device)
+                                     padding_value=padding_id).to(positive_embed_shifted.device)
         print("tokens_padded = ", tokens_padded)

         tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
@@ -601,17 +587,17 @@ class BidirectionalConformer(nn.Module):
             src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.

-        T = bn_memory_shifted.shape[0]
+        T = positive_embed_shifted.shape[0]

         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=positive_embed_shifted.device)

         hidden_predictor = self.reverse_decoder(
-            tgt=bn_memory_shifted,
+            tgt=positive_embed_shifted,
             memory=token_memory,
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)

-        total_prob = self.discrete_bottleneck.compute_prob(
+        total_prob = self.sample_and_predict.compute_prob(
             hidden_predictor,
             sampled,
             softmax,
@@ -676,6 +662,9 @@ class ReverseGrad(torch.autograd.Function):
     def backward(ctx, x_grad):
         return -x_grad

+def reverse_gradient(x: Tensor) -> Tensor:
+    return ReverseGrad.apply(x)
+

 class DebugGrad(torch.autograd.Function):
     @staticmethod
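For readers unfamiliar with the gradient-reversal pattern, here is a minimal self-contained sketch of a function like ReverseGrad (the forward method is assumed, since only backward appears in this hunk):

    import torch
    from torch import Tensor

    class ReverseGradSketch(torch.autograd.Function):
        """Identity in the forward pass; negates the gradient in the backward pass."""
        @staticmethod
        def forward(ctx, x: Tensor) -> Tensor:
            return x

        @staticmethod
        def backward(ctx, x_grad: Tensor) -> Tensor:
            return -x_grad

    x = torch.ones(3, requires_grad=True)
    ReverseGradSketch.apply(x).sum().backward()
    print(x.grad)      # tensor([-1., -1., -1.]): sum() alone would give +1, reversed to -1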
@@ -693,19 +682,21 @@ class DebugGrad(torch.autograd.Function):

-class DiscreteBottleneck(nn.Module):
-    """
-    This layer forces its input through an information bottleneck via
-    a discretization operation with sampling, and allows you to
-    predict the likelihood of those discretized values.
-    We use the torch-flow-sampling
-    package for this, to provide a differentiable softmax that should be
-    much better than Gumbel in terms of actually giving an information
-    bottleneck.
+class SampleAndPredict(nn.Module):
+    """This module discretizes its input and lets you predict the
+    discrete classes.
+    We use the torch-flow-sampling package for this, to provide a differentiable
+    softmax that should be much better than Gumbel in terms of actually giving
+    an information bottleneck.
+    (However, if straight_through_scale == 1.0, which actually seems
+    to be working fine so far, it's the same as just sampling from the
+    categorical distribution and using straight-through derivatives
+    (i.e. the derivatives are as if that output had just been a softmax).
+    This may depend somewhat on the model; straight_through_scale == 0.0
+    is definitely safer from a correctness point of view.

     Args:
-        dim: The input and output dimension of the discrete bottleneck
-           operation.
+        dim: The input feature dimension
        tot_classes: The total number of classes (across all groups
           of classes); each group is separately discretized
        num_groups: The number of groups of classes; discretization
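The straight_through_scale == 1.0 case described in the new docstring is the familiar straight-through estimator: sample a hard one-hot vector in the forward pass, but let gradients flow as if the output were the softmax. A rough sketch of that idea (this is not the torch-flow-sampling API; flow_sample's internals are not shown in this diff):

    import torch
    import torch.nn.functional as F

    def straight_through_sample(logits: torch.Tensor) -> torch.Tensor:
        """One-hot sample over the last dim; gradients follow the softmax."""
        probs = torch.softmax(logits, dim=-1)
        idx = torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1)
        one_hot = F.one_hot(idx.reshape(probs.shape[:-1]), probs.shape[-1]).to(probs.dtype)
        # Forward value equals one_hot; the backward path goes through `probs`.
        return one_hot + probs - probs.detach()

    logits = torch.randn(4, 8, requires_grad=True)
    sampled = straight_through_sample(logits)
    (sampled * torch.arange(8.0)).sum().backward()   # gradients reach `logits` via the softmax term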
@@ -728,10 +719,10 @@ class DiscreteBottleneck(nn.Module):
            1.0 - straight_through_scale.
        min_prob_ratio: For any class whose average softmax
           output, for a given minibatch, is less than
-          min_prob_ratio times
-        include_predictor: If true, include the parameters
-           necessary to predict the likelihoods of the
-           classes from some kind of input embedding.
+          min_prob_ratio times the average probability,
+          boost its probability; this is a mechanism
+          to avoid "losing" classes, we are hoping it won't really
+          be necessary in practice.
    """
    def __init__(
            self,
@@ -739,12 +730,11 @@ class DiscreteBottleneck(nn.Module):
            tot_classes: int,
            num_groups: int,
            interp_prob: float = 1.0,
-           straight_through_scale: float = 0.333,
+           straight_through_scale: float = 0.0,
            min_prob_ratio: float = 0.1,
-           include_predictor: bool = True
    ):
-       super(DiscreteBottleneck, self).__init__()
-       self.norm_in = nn.LayerNorm(dim)
+       super(SampleAndPredict, self).__init__()
        self.linear1 = nn.Linear(dim, tot_classes)

        self.num_groups = num_groups
@@ -753,25 +743,24 @@ class DiscreteBottleneck(nn.Module):
        self.min_prob_ratio = min_prob_ratio
        self.tot_classes = tot_classes
        self.classes_per_group = tot_classes // num_groups

+       # prob_boost relates to the min_prob_ratio setting.  It's not configurable for now.
        self.prob_boost = 1.0e-05

        # class_probs is a rolling mean of the output of the sampling operation.
        # When any element of it gets below self.min_prob_ratio / self.classes_per_group,
        # we boost the class's probability by adding self.prob_boost to
        # that element of self.class_offset
-       self.class_probs_decay = 0.9
+       self.class_probs_decay = 0.95
        self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group)
        # class_offsets is a bias term that we add to logits before the sampling
        # operation in order to enforce that no class is too infrequent
        # (c.f. 'min_prob_ratio').
        self.register_buffer('class_offsets', torch.zeros(tot_classes))

-       self.linear2 = nn.Linear(tot_classes, dim, bias=False)
-       self.norm_out = nn.LayerNorm(dim)
-
-       if include_predictor:
-           # pred_linear predicts the class probabilities from a predictor
-           # embedding.
-           self.pred_linear = nn.Linear(dim, tot_classes)
+       # pred_linear predicts the class probabilities from a predictor
+       # embedding of dimension 'dim' supplied by the user.
+       self.pred_linear = nn.Linear(dim, tot_classes)

        if self.num_groups > 1:
@@ -790,7 +779,13 @@ class DiscreteBottleneck(nn.Module):
            #         [ True, True, True, True],
            #         [ True, True, True, True]])
            self.register_buffer('pred_cross_mask',
-                                ((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
+                                ((torch.arange(d) // c).unsqueeze(1) >=
+                                 (torch.arange(d) // c).unsqueeze(0)))
+
+       # linear2 and post_layer_norm come after the sampling.
+       self.linear2 = nn.Linear(tot_classes, dim)
+       self.post_layer_norm = nn.LayerNorm(dim)

        self._reset_parameters()

    def _reset_parameters(self):
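A quick illustration of the pred_cross_mask expression above, using hypothetical small sizes (d standing for the total class dimension and c for classes per group, as in the surrounding code):

    import torch

    d, c = 4, 2    # e.g. 2 groups of 2 classes (illustrative values only)
    mask = (torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)
    print(mask)
    # tensor([[ True,  True, False, False],
    #         [ True,  True, False, False],
    #         [ True,  True,  True,  True],
    #         [ True,  True,  True,  True]])

i.e. a block lower-triangular pattern over the groups, matching the all-True rows visible in the comment in the hunk above.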
@@ -798,17 +793,14 @@ class DiscreteBottleneck(nn.Module):
        torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))

-   def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+   def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
        """
-       Forward computation.
+       Forward computation.  See also compute_prob().

        Args:
          x: The input tensor, of shape (S, N, E) where S is the sequence length,
             N is the batch size and E is the embedding dim.

-       Returns (embedding, sampled, softmax), where:
-         embedding: of shape (S, N, E) where E is the embedding dimension (`dim` arg
-            to the constructor), this is the output embedding; it is projected from the
-            sampled class probabilities.
+       Returns (sampled, softmax, positive_embed, negative_embed), where:
          sampled: of shape (S, N, C) where C is the `tot_classes` to the
             constructor, these are the sampled one-hot vectors or interpolations
             thereof.  They will be needed if we try to predict the discrete values
@@ -820,19 +812,31 @@ class DiscreteBottleneck(nn.Module):
             expectation of the result of sampling -> lower-variance derivatives.
             This is unnecessary if straight_through_scale == 1.0, since in that
             case it would not affect the backpropagated derivatives.
+         positive_embed: The samples projected back down to the embedding
+            dimension, and layer-normed (`dim` passed to the constructor).
+         negative_embed: This is numerically the same value as positive_embed,
+            but has its gradients reversed prior to the projection and
+            LayerNorm.  This is intended to be used for terms that appear
+            with opposite sign in the loss function, to be fed to
+            something whose gradient is already (going to be) reversed:
+            specifically, the self-prediction network.
        """
-       x = self.norm_in(x) * 5  # * 5 gives lower entropy.. starts training faster..
-       x = self.linear1(x)
+       x = self.linear1(x * 5)  # multiplying by 5 gives lower entropy, makes it
+                                # begin training faster..

-       x = x + self.class_offsets
+       if self.min_prob_ratio > 0.0:
+           x = x + self.class_offsets

        (S, N, tot_classes) = x.shape
        x = x.reshape(S, N, self.num_groups, self.classes_per_group)

-       # This is a little wasteful since we already compute the softmax
-       # inside 'flow_sample'.
+       # This is a little wasteful since we already compute the softmax inside
+       # 'flow_sample'.
        softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None

        if random.random() < 0.001:
+           # Some info that's useful for debug.
            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
            negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
@@ -841,8 +845,9 @@ class DiscreteBottleneck(nn.Module):
            global_log_softmax = (global_softmax + 1.0e-20).log()
            global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()
-           print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item())
+           print("SampleAndPredict: entropy = ",
+                 -negentropy.to('cpu').item(), ", averaged entropy = ",
+                 -global_negentropy.to('cpu').item())

        x = torch_flow_sampling.flow_sample(x,
@@ -854,19 +859,22 @@ class DiscreteBottleneck(nn.Module):
        sampled = x

-       if self.training and False:
+       if self.training and self.min_prob_ratio > 0.0:
            mean_class_probs = torch.mean(x.detach(), dim=(0,1))
            self.class_probs = (self.class_probs * self.class_probs_decay +
                                mean_class_probs * (1.0 - self.class_probs_decay))
            prob_floor = self.min_prob_ratio / self.classes_per_group
            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost

-       embedding = self.norm_out(self.linear2(x))
-
-       #if random.random() < 0.01:
-       #    return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"))
-       #else:
-       return (embedding, sampled, softmax)
+       positive_embed = self.post_layer_norm(self.linear2(sampled))
+       negative_embed = self.post_layer_norm(self.linear2(reverse_gradient(sampled)))
+
+       if random.random() < 0.002:
+           return (DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"),
+                   positive_embed, negative_embed)
+       else:
+           return (sampled, softmax, positive_embed, negative_embed)

    def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
@@ -891,23 +899,23 @@ class DiscreteBottleneck(nn.Module):
          padding_mask: Optionally, a boolean tensor of shape (N, S), i.e.
             (batch_size, sequence_length), with True in masked positions
             that are to be ignored in the sum of probabilities.
          reverse_grad: If true, negate the gradient that is passed back
             to 'x' and to the modules self.pred_linear and pred_cross.
             This will be useful in computing a loss function that has
             a likelihood term with negative sign (i.e. the self-prediction).
-            We'll later need negate the gradient one more more time
-            where we give the input to the prediction module that
-            generated 'x'.
+            We'll later need to negate this gradient one more time
+            (it's expected, when reverse_grad == True, that x would
+            derive somehow from `negative_embed`, so that the gradient
+            will eventually go back to the correct sign.)

       Returns a scalar Tensor representing the total probability.
       """
       if reverse_grad:
-          sampled = ReverseGrad.apply(sampled)
+          sampled = reverse_gradient(sampled)
       if softmax is None:
           softmax = sampled
       elif reverse_grad:
-          softmax = ReverseGrad.apply(softmax)
+          softmax = reverse_gradient(softmax)

       logprobs = self.pred_linear(x)
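A tiny check of the double-reversal reasoning in this docstring: reversing a gradient twice along the same path restores its sign, so a predictor whose input comes from negative_embed (already reversed) and whose output is reversed again inside compute_prob is trained in the ordinary "learn to predict" direction. This sketch uses a detach-based stand-in for ReverseGrad, purely for illustration:

    import torch

    def reverse_gradient(x: torch.Tensor) -> torch.Tensor:
        # Same effect as ReverseGrad: identity forward, negated gradient backward.
        return 2 * x.detach() - x

    w = torch.tensor(3.0, requires_grad=True)

    (reverse_gradient(w) * 5).backward()
    print(w.grad)      # tensor(-5.): one reversal flips the sign

    w.grad = None
    reverse_gradient(reverse_gradient(w) * 5).backward()
    print(w.grad)      # tensor(5.): two reversals cancel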
@@ -942,7 +950,7 @@ class DiscreteBottleneck(nn.Module):
       tot_prob = (logprobs * softmax).sum()
       if reverse_grad:
-          tot_prob = ReverseGrad.apply(tot_prob)
+          tot_prob = reverse_gradient(tot_prob)
       return tot_prob
@@ -1814,28 +1822,27 @@ def _test_bidirectional_conformer():
    print("tokens = ", tokens)
    print("supervision = ", supervision)

    # memory: [T, N, C]
-   (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision)
+   (memory, pos_emb, key_padding_mask) = m(feats, supervision)

    # ctc_output: [N, T, C].
    ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)

+   (sampled, softmax, positive_embed_shifted, negative_embed_shifted) = m.sample_forward(memory)
+
    decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens,
                                        sos_id=1,
                                        eos_id=2)
    print("decoder logprob = ", decoder_logprob)

-   (T, N, E) = bn_memory.shape
-   bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
-
    reverse_decoder_logprob = m.reverse_decoder_forward(
-       bn_memory_shifted, key_padding_mask,
+       positive_embed_shifted, key_padding_mask,
        sampled, softmax, tokens,
        sos_id=1, eos_id=2, padding_id=0)
    print("reverse decoder logprob = ", reverse_decoder_logprob)

    self_prediction_logprob = m.self_prediction_forward(
-       bn_memory_shifted, key_padding_mask,
+       negative_embed_shifted, key_padding_mask,
        sampled, softmax)
    print("self prediction logprob = ", self_prediction_logprob)
@@ -1850,24 +1857,24 @@ def _test_discrete_bottleneck():
    tot_classes = 256
    num_groups = 8
    interp_prob = 1.0
-   straight_through_scale = 0.0 # will change
+   straight_through_scale = 1.0 # will change
    need_softmax = True

-   b = DiscreteBottleneck(dim, tot_classes, num_groups,
+   b = SampleAndPredict(dim, tot_classes, num_groups,
                           interp_prob, straight_through_scale).to(device)

-   self_predictor = nn.Linear(tot_classes, dim).to(device)
    from_feats_predictor = nn.Linear(dim, dim).to(device)
-   model = nn.ModuleList([b, self_predictor, from_feats_predictor])
+   from_negative_embed_predictor = nn.Linear(dim, dim).to(device)
+   model = nn.ModuleList([b, from_feats_predictor, from_negative_embed_predictor])

    model.train()

    optim = torch.optim.Adam(params=model.parameters(),
                             lr=3.0e-04)

-   scale = 0.3 # determines the feature correlation..should be between 0 and 1.
+   scale = 0.5 # determines the feature correlation..should be between 0 and 1.
    # https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
    # scale corresponds to rho^2, rho being sqrt(scale).
    mutual_information = dim * -0.5 * math.log(1.0 - scale)
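For reference, the correlation-to-MI formula used just above works out as follows (a small arithmetic check; the value of dim is set earlier in the test and is not visible in this hunk):

    import math

    scale = 0.5                               # rho^2, as set above
    mi_per_dim = -0.5 * math.log(1.0 - scale)
    print(mi_per_dim)                         # ~0.3466 nats per dimension; the test
                                              # multiplies this by dim to get the total MI.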
@@ -1886,37 +1893,34 @@ def _test_discrete_bottleneck():
            #print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}")

-           bn_memory, sampled, softmax = b(feats)
+           sampled, softmax, positive_embed, negative_embed = b(feats)
+
+           E = dim
+           negative_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device),
+                                               negative_embed[:-1,:,:]), dim=0)
+           positive_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device),
+                                               positive_embed[:-1,:,:]), dim=0)

            # using feats2 instead of feats will limit the mutual information,
            # to the MI between feats and feats2, which we computed and printed
            # above as mutual_information.
            predictor = from_feats_predictor(feats2)
            prob = b.compute_prob(predictor, sampled, softmax)

            if True:
-               sampled_reversed = ReverseGrad.apply(sampled)
-               predictor_reversed = self_predictor(sampled_reversed)
-               if True:
-                   predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                           predictor_reversed[:-1,:,:]), dim=0)
-               else:
-                   # skip shifting.. want to see the effect..
-                   predictor_reversed_shifted = predictor_reversed
-               self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
-                                          reverse_grad=True)
+               predictor_shifted = from_negative_embed_predictor(negative_embed_shifted)
+               self_prob = b.compute_prob(predictor_shifted, sampled, softmax,
+                                          reverse_grad=True)

-           normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
-           normalized_prob = (prob / (T * N)).to('cpu').item()
+           normalized_self_prob = (self_prob / (T * N))
+           normalized_prob = (prob / (T * N))

-           loss = -(prob - self_prob)
-           normalized_loss = loss / (T * N)
+           normalized_loss = -(normalized_prob - normalized_self_prob)

            if i % 200 == 0:
-               print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
+               print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob.to('cpu').item()} - {-normalized_self_prob.to('cpu').item()} = {normalized_loss.to('cpu').item()}")

            normalized_loss.backward()