mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-15 13:02:23 +00:00)

commit b0dd4215fe (parent 0f29f35a42)

    Refactor so there is no bottleneck, only prediction
@@ -71,7 +71,7 @@ class ConformerTrunk(nn.Module):
     def forward(
         self, x: torch.Tensor, supervision: Optional[Supervisions] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
           x:
@@ -149,8 +149,7 @@ class BidirectionalConformer(nn.Module):
           num_encoder_layers: Number of encoder layers in the "trunk" that
              encodes the acoustic features
           num_ctc_encoder_layers: Number of layers in the CTC encoder
-             that comes after the trunk (and possibly the discrete
-             bottleneck, if bypass_bottleneck == True.
+             that comes after the trunk.
              These are just conformer encoder layers.
           num_decoder_layers: Number of layers in the attention decoder;
              this goes from the trunk to the word-pieces or phones.
@@ -168,17 +167,14 @@ class BidirectionalConformer(nn.Module):
              useful function is to prevent "trivial" solutions
              such as collapse of the distribution to a single symbol,
              or symbols that are highly correlated across time.
-          bypass_bottleneck: If true, bypass the discrete bottleneck
-             when predicting the CTC output and the decoder
-             that decodes the word-pieces or phones.
           dropout: Dropout probability
           cnn_module_kernel: Kernel size in forward conformer layers
           is_bpe: If false, we'll add one (for EOS) to the number of
              classes at the output of the decoder
           use_feat_batchnorm: If true, apply batchnorm to the input features.
-          discrete_bottleneck_tot_classes: Total number of classes
+          discretization_tot_classes: Total number of classes
              (across all groups) in the discrete bottleneck
-          discrete_bottleneck_num_groups: Number of groups of classes/symbols
+          discretization_num_groups: Number of groups of classes/symbols
              in the discrete bottleneck
        """
     def __init__(
@@ -195,18 +191,15 @@ class BidirectionalConformer(nn.Module):
         num_reverse_encoder_layers: int = 4,
         num_reverse_decoder_layers: int = 4,
         num_self_predictor_layers: int = 3,
-        bypass_bottleneck: bool = True,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         is_bpe: bool = False,
         use_feat_batchnorm: bool = True,
-        discrete_bottleneck_tot_classes: int = 512,
-        discrete_bottleneck_num_groups: int = 4
+        discretization_tot_classes: int = 512,
+        discretization_num_groups: int = 4
     ) -> None:
         super(BidirectionalConformer, self).__init__()

-        self.bypass_bottleneck = bypass_bottleneck
-
         self.trunk = ConformerTrunk(num_features, subsampling_factor,
                                     d_model, nhead, dim_feedforward,
                                     num_trunk_encoder_layers, dropout,
@@ -324,53 +317,64 @@ class BidirectionalConformer(nn.Module):
               for _ in range(num_self_predictor_layers)])

-        self.discrete_bottleneck = DiscreteBottleneck(
+        self.sample_and_predict = SampleAndPredict(
             dim=d_model,
-            tot_classes=discrete_bottleneck_tot_classes,
-            num_groups=discrete_bottleneck_num_groups)
+            tot_classes=discretization_tot_classes,
+            num_groups=discretization_num_groups)


     def forward(self, x: Tensor, supervision: Optional[Supervisions] = None,
-                need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
+                need_softmax: bool = True) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         """
-        Forward function that "encodes" the features.
+        Forward function that "encodes" the acoustic features through the "trunk"
+        (the shared part of the encoding of the encoding of the acoustic features)

         Args:
           x:
             The input tensor. Its shape is [N, T, F], i.e. [batch_size, num_frames, num_features].
           supervision:
-            Supervision in lhotse format (optional)
+            Supervision in lhotse format (optional; needed only for acoustic length
+            information)
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
             (CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling). Used only to compute masking information.
           need_softmax:
             If true, the last output ("softmax") will be computed. This can be useful
             in the reverse model, but only necessary if straight_through_scale != 1.0.

-        Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where:
+        Returns: (memory, pos_emb, key_padding_mask), where:

           memory: a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
              is actually a subsampled form of the num_frames of the input `x`.
-             If self.bypass_bottleneck, it will be taken before the discrete
-             bottleneck; otherwise, from after.
-          bn_memory: The same shape as `memory`, but comes after the discrete bottleneck
-             regardless of the value of self.bypass_bottleneck.
           pos_emb: The relative positional embedding; will be given to ctc_encoder_forward()
-          sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
-             as given to the constructor. This will be needed for the 'reverse' model.
-          softmax: a "soft" version of `sampled`. Will only be returned if need_softmax == True;
-             else will be None.
           key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of
              shape [N, T] (only if supervision was supplied, else None).
        """
-        encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
+        memory, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)
+        return memory, pos_emb, memory_key_padding_mask

-        bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output)
-
-        memory = encoder_output if self.bypass_bottleneck else bn_memory
+    def sample_forward(self, memory: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
+        """
+        Given the "memory" from forward(), run the sample_and_redict module.
+        See documentation for forward() of class SampleAndPredict for more info.

-        return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask)
+        Returns (sampled, softmax, positive_embed_shifted, negative_embed_shifted),
+        where positive_embed_shifted, for instance, is positive_embed
+        shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
+            (T, N, E) = positive_embed.shape
+            positive_embed_shifted = torch.cat((torch.zeros(1, N, E), positive_embed[:-1,:,:]), dim=0)
+        """
+        (sampled, softmax, positive_embed, negative_embed) = self.sample_and_predict(memory)
+
+        (T, N, E) = memory.shape
+        device = memory.device
+        zeros = torch.zeros(1, N, E).to(memory.device)
+        negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
+        positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
+
+        return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)

     def decoder_forward(
             self,
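Note (illustrative, not part of the commit): the shift-by-one that sample_forward() performs can be checked in isolation. The snippet below uses made-up shapes; `embed` stands in for positive_embed or negative_embed.

# Illustrative sketch only: the shift-by-one used in sample_forward().
import torch

T, N, E = 5, 2, 4
embed = torch.randn(T, N, E)      # stands in for positive_embed / negative_embed

zeros = torch.zeros(1, N, E)
embed_shifted = torch.cat((zeros, embed[:-1, :, :]), dim=0)

# Frame t of the shifted tensor holds frame t - 1 of the original, so a
# predictor fed with it never sees the frame it is asked to predict.
assert torch.equal(embed_shifted[1:], embed[:-1])
assert torch.equal(embed_shifted[0], zeros[0])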
@@ -479,11 +483,10 @@ class BidirectionalConformer(nn.Module):

     def self_prediction_forward(
             self,
-            bn_memory_shifted: torch.Tensor,
+            negative_embed_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
-            softmax: Optional[torch.Tensor],
-            reverse_grad: bool = True) -> Tensor:
+            softmax: Optional[torch.Tensor]) -> Tensor:
        """
        Returns the total log-prob of the the labels sampled in the discrete
        bottleneck layer, as predicted using a relatively simple model that
@@ -491,45 +494,29 @@ class BidirectionalConformer(nn.Module):
        [Appears on the denominator of an expression for mutual information].

        Args:
-          bn_memory_shifted:
-             It's the bn_memory output of forward(), with shape [T, N, E], shifted
-             by one so that bn_shifted_memory[t] == bn_memory[t-1], as in:
-                (T, N, E) = bn_memory.shape
-                bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
+          negative_embed_shifted:
+             The negative_embed_shifted output of self.sample_forward(), with shape [T, N, E]
          memory_key_padding_mask:
             The padding mask from the encoder, of shape [N, T], boolean, True
             for masked locations.
          sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C]
-            where C corresponds to `discrete_bottleneck_tot_classes`
+            where C corresponds to `discretization_tot_classes`
             as given to the constructor. This will be needed for the 'reverse'
             model.
          softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
-          reverse_grad: will likely be true. If true, the gradient is reversed twice
-             in this computation, so that we train predictors with the correct
-             gradient, i.e. to predict, not anti-predict (since the return value
-             of this function will appear with positive, not negative, sign in the
-             loss function, so will be minimized).
-             The gradient w.r.t. the non-self inputs to this function, though (i.e.
-             bn_memory_shifted, sampled, softmax) will not be reversed, though.

        Returns:
-           A scalar tensor, the **sum** of label smoothing loss over utterances
+           A scalar tensor, the **sum** of the log-prob loss over utterances
           in the batch without any normalization.
        """

-        if reverse_grad:
-            # Reversing gradient for bn_memory_shifted puts the gradient back into
-            # the correct sign; we reversed it in
-            # self.discrete_bottleneck.compute_prob(), in order that
-            # self.self.predictor_encoder can learn to predict.
-            # (You have to read the code in reverse, to reason about
-            # what happens to the gradients).
-            bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)

        # no mask is needed for self_predictor_encoder; its CNN
-        # layer uses left-padding only, making it causal.
-        predictor = self.self_predictor_encoder(bn_memory_shifted)
+        # layer uses left-padding only, making it causal, so the mask
+        # is redundant (it wouldn't affect any of the
+        # outputs we care about).
+        predictor = self.self_predictor_encoder(negative_embed_shifted)

-        prob = self.discrete_bottleneck.compute_prob(predictor,
+        prob = self.sample_and_predict.compute_prob(predictor,
                                                      sampled, softmax,
                                                      memory_key_padding_mask,
                                                      reverse_grad=True)
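Note (illustrative, not from the commit): the docstring says this log-prob appears on the denominator of a mutual-information expression. A toy sketch, with hypothetical names, of how the two log-prob terms might be combined; it mirrors the combination used in the test function at the bottom of the file.

import torch

def mutual_information_style_loss(reverse_decoder_logprob: torch.Tensor,
                                  self_prediction_logprob: torch.Tensor,
                                  num_frames: int) -> torch.Tensor:
    # Both arguments are scalar sums of log-probs over the batch.  The
    # prediction from the word sequence enters with positive sign (numerator),
    # the prediction from previous frames with negative sign (denominator);
    # minimizing the returned value maximizes the MI-style objective.
    return -(reverse_decoder_logprob - self_prediction_logprob) / num_frames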
@@ -538,7 +525,7 @@ class BidirectionalConformer(nn.Module):

     def reverse_decoder_forward(
             self,
-            bn_memory_shifted: torch.Tensor,
+            positive_embed_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
@@ -553,15 +540,14 @@ class BidirectionalConformer(nn.Module):
        supervision word-sequence.

        Args:
-          bn_memory_shifted:
-             It's the bn_memory output of forward(), with shape [T, N, E], shifted
-             by one so that shifted_memory[t] == bn_memory[t-1], as in:
-                (T, N, E) = memory.shape
-                bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
+          positive_embed_shifted:
+             It's the positive_embed_shifted output of self.sample_forward(), with
+             shape [T, N, E]
          memory_key_padding_mask:
             The padding mask from the encoder.
-          sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
-             as given to the constructor. This will be needed for the 'reverse' model.
+          sampled: is a Tensor of shape [T, N, C] where C corresponds to
+             `discretization_tot_classes` as given to the constructor.
+             This will be needed for the 'reverse' model.
          softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
          token_ids:
            A list-of-list IDs. Each sublist contains IDs for an utterance.
@@ -583,7 +569,7 @@ class BidirectionalConformer(nn.Module):
        token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]

        tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
-                                     padding_value=padding_id).to(bn_memory_shifted.device)
+                                     padding_value=padding_id).to(positive_embed_shifted.device)

        print("tokens_padded = ", tokens_padded)
        tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
@@ -601,17 +587,17 @@ class BidirectionalConformer(nn.Module):
                                           src_key_padding_mask=tokens_key_padding_mask)
        # token_memory is of shape (S, N, C), if S is length of token sequence.

-        T = bn_memory_shifted.shape[0]
+        T = positive_embed_shifted.shape[0]
        # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=positive_embed_shifted.device)

        hidden_predictor = self.reverse_decoder(
-            tgt=bn_memory_shifted,
+            tgt=positive_embed_shifted,
            memory=token_memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=tokens_key_padding_mask)

-        total_prob = self.discrete_bottleneck.compute_prob(
+        total_prob = self.sample_and_predict.compute_prob(
            hidden_predictor,
            sampled,
            softmax,
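Aside (an assumption, since the helper's body is not shown in this diff): generate_square_subsequent_mask builds the usual causal attention mask for the reverse decoder. A common construction looks like the following sketch, though icefall's helper may differ in detail.

import torch

def square_subsequent_mask(size: int, device="cpu") -> torch.Tensor:
    # 0.0 where position j <= i (may attend), -inf where j > i (the future).
    return torch.triu(torch.full((size, size), float("-inf"), device=device),
                      diagonal=1)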
@@ -676,6 +662,9 @@ class ReverseGrad(torch.autograd.Function):
     def backward(ctx, x_grad):
         return -x_grad

+def reverse_gradient(x: Tensor) -> Tensor:
+    return ReverseGrad.apply(x)
+

 class DebugGrad(torch.autograd.Function):
     @staticmethod
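For reference, a self-contained sketch (not the file's exact code) of what a gradient-reversal Function like ReverseGrad does: the forward pass is the identity and the backward pass negates the gradient, so applying it twice along a path restores the original sign. That is the double-reversal trick self_prediction_forward() and compute_prob() rely on.

# Sketch of gradient reversal; hypothetical names, mirrors ReverseGrad above.
import torch
from torch import Tensor


class _ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # view_as keeps a distinct output tensor while leaving values unchanged.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tensor:
        return -x_grad


def _reverse_gradient(x: Tensor) -> Tensor:
    return _ReverseGrad.apply(x)


x = torch.ones(3, requires_grad=True)
_reverse_gradient(x).sum().backward()
assert torch.equal(x.grad, -torch.ones(3))       # single reversal: negated gradient

x.grad = None
_reverse_gradient(_reverse_gradient(x)).sum().backward()
assert torch.equal(x.grad, torch.ones(3))        # double reversal: original sign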
@@ -693,19 +682,21 @@ class DebugGrad(torch.autograd.Function):


-class DiscreteBottleneck(nn.Module):
-    """
-    This layer forces its input through an information bottleneck via
-    a discretization operation with sampling, and allows you to
-    predict the likelihood of those discretized values.
-    We use the torch-flow-sampling
-    package for this, to provide a differentiable softmax that should be
-    much better than Gumbel in terms of actually giving an information
-    bottleneck.
+class SampleAndPredict(nn.Module):
+    """This module discretizes its input and lets you predict the
+    discrete classes.
+    We use the torch-flow-sampling package for this, to provide a differentiable
+    softmax that should be much better than Gumbel in terms of actually giving
+    an information bottleneck.
+    (However, if straight_through_scale == 1.0, which actually seems
+    to be working fine so far, it's the same as just sampling from the
+    categorical distribution and using straight-through derivatives
+    (i.e. the derivatives are as if that output had just been a softmax).
+    This may depend somewhat on the model; straight_through_scale == 0.0
+    is definitely safer from a correctness point of view.

     Args:
-      dim: The input and output dimension of the discrete bottleneck
-         operation.
+      dim: The input feature dimension
       tot_classes: The total number of classes (across all groups
          of classes); each group is separately discretized
      num_groups: The number of groups of classes; discretization
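The docstring above notes that straight_through_scale == 1.0 amounts to plain categorical sampling with straight-through derivatives. A minimal sketch of that estimator follows; this is not the torch-flow-sampling API, and the names are illustrative.

# Minimal straight-through categorical sampling sketch (illustrative only).
import torch

def straight_through_sample(logits: torch.Tensor) -> torch.Tensor:
    """logits: (..., num_classes).  Returns one-hot samples whose forward value
    is the sample but whose gradient is that of the softmax."""
    probs = logits.softmax(dim=-1)
    flat = probs.reshape(-1, probs.shape[-1])
    indices = torch.multinomial(flat, 1)                 # one sample per row
    one_hot = torch.zeros_like(flat).scatter_(1, indices, 1.0).reshape(probs.shape)
    # Forward: one_hot (the extra terms cancel numerically); backward: follows probs.
    return one_hot + probs - probs.detach()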
@@ -728,10 +719,10 @@ class DiscreteBottleneck(nn.Module):
         1.0 - straight_through_scale.
      min_prob_ratio: For any class whose average softmax
         output, for a given minibatch, is less than
-        min_prob_ratio times
-     include_predictor: If true, include the parameters
-        necessary to predict the likelihoods of the
-        classes from some kind of input embedding.
+        min_prob_ratio times the average probability,
+        boost its probability; this is a mechanism
+        to avoid "losing" classes, we are hoping it won't really
+        be necessary in practice.
    """
    def __init__(
            self,
@@ -739,12 +730,11 @@ class DiscreteBottleneck(nn.Module):
            tot_classes: int,
            num_groups: int,
            interp_prob: float = 1.0,
-            straight_through_scale: float = 0.333,
+            straight_through_scale: float = 0.0,
            min_prob_ratio: float = 0.1,
            include_predictor: bool = True
    ):
-        super(DiscreteBottleneck, self).__init__()
-        self.norm_in = nn.LayerNorm(dim)
+        super(SampleAndPredict, self).__init__()

        self.linear1 = nn.Linear(dim, tot_classes)

        self.num_groups = num_groups
@@ -753,25 +743,24 @@ class DiscreteBottleneck(nn.Module):
        self.min_prob_ratio = min_prob_ratio
        self.tot_classes = tot_classes
        self.classes_per_group = tot_classes // num_groups

        # prob_boost relates to the min_prob_ratio setting. It's not configurable for now.
        self.prob_boost = 1.0e-05

        # class_probs is a rolling mean of the output of the sampling operation.
        # When any element of it gets below self.min_prob_ratio / self.classes_per_group,
        # we boost the class's probability by adding self.prob_boost to
        # that element of self.class_offset
-        self.class_probs_decay = 0.9
+        self.class_probs_decay = 0.95
        self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group)
        # class_offsets is a bias term that we add to logits before the sampling
        # operation in order to enforce that no class is too infrequent
        # (c.f. 'min_prob_ratio').
        self.register_buffer('class_offsets', torch.zeros(tot_classes))

-        self.linear2 = nn.Linear(tot_classes, dim, bias=False)
-        self.norm_out = nn.LayerNorm(dim)

        if include_predictor:
            # pred_linear predicts the class probabilities from a predictor
-            # embedding.
+            # embedding of dimension 'dim' supplied by the user.
            self.pred_linear = nn.Linear(dim, tot_classes)

            if self.num_groups > 1:
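A standalone sketch of the min_prob_ratio mechanism the comments above describe (made-up sizes, hypothetical function name): keep a rolling mean of how often each class is used, and add a small offset to the logits of any class whose usage falls below the floor.

# Illustrative sketch of the class-frequency boosting mechanism.
import torch

tot_classes, classes_per_group = 8, 4
min_prob_ratio, prob_boost, decay = 0.1, 1.0e-05, 0.95

class_probs = torch.ones(tot_classes) / classes_per_group    # rolling usage estimate
class_offsets = torch.zeros(tot_classes)                     # added to logits before sampling

def boost_rare_classes(sampled: torch.Tensor,
                       class_probs: torch.Tensor,
                       class_offsets: torch.Tensor) -> torch.Tensor:
    """sampled: (S, N, tot_classes), the output of the sampling operation."""
    mean_class_probs = sampled.detach().mean(dim=(0, 1))
    class_probs = class_probs * decay + mean_class_probs * (1.0 - decay)
    prob_floor = min_prob_ratio / classes_per_group
    class_offsets += (class_probs < prob_floor).float() * prob_boost  # in-place update
    return class_probs

class_probs = boost_rare_classes(torch.rand(10, 3, tot_classes),
                                 class_probs, class_offsets)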
@@ -790,7 +779,13 @@ class DiscreteBottleneck(nn.Module):
            #         [ True, True, True, True],
            #         [ True, True, True, True]])
            self.register_buffer('pred_cross_mask',
-                                 ((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
+                                 ((torch.arange(d) // c).unsqueeze(1) >=
+                                  (torch.arange(d) // c).unsqueeze(0)))
+
+        # linear2 and post_layer_norm come after the sampling.
+        self.linear2 = nn.Linear(tot_classes, dim)
+        self.post_layer_norm = nn.LayerNorm(dim)

        self._reset_parameters()

    def _reset_parameters(self):
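To make the mask in the comment above concrete, here is the same expression evaluated on tiny sizes (illustrative values): with d total classes split into groups of c, group i is allowed to condition on group j only when i >= j, giving a block lower-triangular mask.

# Sketch of the pred_cross_mask pattern on tiny sizes.
import torch

d, c = 4, 2
group = torch.arange(d) // c              # tensor([0, 0, 1, 1])
mask = group.unsqueeze(1) >= group.unsqueeze(0)
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])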
@@ -798,17 +793,14 @@ class DiscreteBottleneck(nn.Module):
            torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))


-    def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
        """
-        Forward computation.
+        Forward computation. See also compute_prob().
        Args:
          x: The input tensor, of shape (S, N, E) where S is the sequence length,
             N is the batch size and E is the embedding dim.
-        Returns (embeddding, sampled, softmax), where:
+        Returns (sampled, softmax, positive_embed, negative_embed), where:

-          embedding: of shape (S, N, E) where E is the embedding dimension (`dim` arg
-             to the constructor), this is the output embedding; it is projected from the
-             sampled class probabilities.
          sampled: of shape (S, N, C) where C is the `tot_classes` to the
             constructor, these are the sampled one-hot vectors or interpolations
             thereof. They will be needed if we try to predict the discrete values
@@ -820,19 +812,31 @@ class DiscreteBottleneck(nn.Module):
             expectation of the result of sampling -> lower-variance derivatives.
             This is unnecessary if straight_through_scale == 1.0, since in that
             case it would not affect the backpropagated derivatives.
+          positive_embed: The samples projected back down to the embedding
+             dimension, and layer-normed (`dim` passed to the constructor).
+          negative_embed: This is numerically the same value as positive_embed,
+             but has its gradients reversed prior to the projection and
+             LayerNorm. This is intended to be used for terms that appear
+             with opposite sign in the loss function, to be fed to
+             something whose gradient is already (going to be) reversed:
+             specifically, the self-prediction network.
        """
-        x = self.norm_in(x) * 5  # * 5 gives lower entropy.. starts training faster..
-        x = self.linear1(x)
+        x = self.linear1(x * 5)  # multiplying 5 gives lower entropy, makes it
+                                 # begin training faster..

        if self.min_prob_ratio > 0.0:
            x = x + self.class_offsets

        (S, N, tot_classes) = x.shape

        x = x.reshape(S, N, self.num_groups, self.classes_per_group)

-        # This is a little wasteful since we already compute the softmax
-        # inside 'flow_sample'.
+        # This is a little wasteful since we already compute the softmax inside
+        # 'flow_sample'.
        softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None

        if random.random() < 0.001:
            # Some info that's useful for debug.
            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
            negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean()
@@ -841,8 +845,9 @@ class DiscreteBottleneck(nn.Module):
            global_log_softmax = (global_softmax + 1.0e-20).log()
            global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean()

-            print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item())
+            print("SampleAndPredict: entropy = ",
+                  -negentropy.to('cpu').item(), ", averaged entropy = ",
+                  -global_negentropy.to('cpu').item())


        x = torch_flow_sampling.flow_sample(x,
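The debug printout above reports the entropy of the per-group softmax; an isolated sketch of that statistic, with made-up shapes, is:

# Illustrative: per-group entropy of a grouped softmax, averaged over (time, batch, group).
import torch

S, N, num_groups, classes_per_group = 10, 3, 4, 64
softmax = torch.rand(S, N, num_groups, classes_per_group)
softmax = softmax / softmax.sum(dim=-1, keepdim=True)

log_softmax = (softmax + 1.0e-20).log()
negentropy = (softmax * log_softmax).sum(-1).mean()
print("entropy =", -negentropy.item())   # ln(classes_per_group) is the maximum possible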
@@ -854,19 +859,22 @@ class DiscreteBottleneck(nn.Module):

        sampled = x

-        if self.training and False:
+        if self.training and self.min_prob_ratio > 0.0:
            mean_class_probs = torch.mean(x.detach(), dim=(0,1))
            self.class_probs = (self.class_probs * self.class_probs_decay +
                                mean_class_probs * (1.0 - self.class_probs_decay))
            prob_floor = self.min_prob_ratio / self.classes_per_group
            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost

-        embedding = self.norm_out(self.linear2(x))
-
-        #if random.random() < 0.01:
-        #    return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"))
-        #else:
-        return (embedding, sampled, softmax)
+        positive_embed = self.post_layer_norm(self.linear2(sampled))
+        negative_embed = self.post_layer_norm(self.linear2(reverse_gradient(sampled)))
+
+        if random.random() < 0.002:
+            return (DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"),
+                    positive_embed, negative_embed)
+        else:
+            return (sampled, softmax, positive_embed, negative_embed)


    def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
@@ -891,23 +899,23 @@ class DiscreteBottleneck(nn.Module):
          padding_mask: Optionally, a boolean tensor of shape (N, S), i.e.
             (batch_size, sequence_length), with True in masked positions
             that are to be ignored in the sum of probabilities.

          reverse_grad: If true, negate the gradient that is passed back
             to 'x' and to the modules self.pred_linear and pred_cross.
             This will be useful in computing a loss function that has
             a likelihood term with negative sign (i.e. the self-prediction).
-            We'll later need negate the gradient one more more time
-            where we give the input to the prediction module that
-            generated 'x'.
+            We'll later need to negate this gradient one more more time
+            (it's expected, when reverse_grad == True, that x would
+            derive somehow from `negative_embed`, so that the gradient
+            will eventually go back to the correct sign.)

        Returns a scalar Tensor represnting the total probability.
        """
        if reverse_grad:
-            sampled = ReverseGrad.apply(sampled)
+            sampled = reverse_gradient(sampled)
        if softmax is None:
            softmax = sampled
        elif reverse_grad:
-            softmax = ReverseGrad.apply(softmax)
+            softmax = reverse_gradient(softmax)

        logprobs = self.pred_linear(x)

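As a rough guide (an approximation; the full body of compute_prob is not visible in this diff), the quantity it returns is essentially a per-group log-softmax of the predictor's logits, weighted by the (soft) sampled targets and summed over unmasked frames:

# Hedged sketch only; intermediate steps of the real compute_prob may differ.
import torch

def approx_compute_prob(logits: torch.Tensor,      # (S, N, tot_classes) from the predictor
                        softmax: torch.Tensor,     # (S, N, tot_classes) soft targets
                        num_groups: int,
                        padding_mask: torch.Tensor = None   # (N, S), True = masked
                        ) -> torch.Tensor:
    S, N, C = logits.shape
    logprobs = logits.reshape(S, N, num_groups, C // num_groups).log_softmax(dim=-1)
    logprobs = logprobs.reshape(S, N, C)
    per_frame = (logprobs * softmax).sum(dim=-1)    # (S, N)
    if padding_mask is not None:
        per_frame = per_frame.masked_fill(padding_mask.t(), 0.0)
    return per_frame.sum()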
@@ -942,7 +950,7 @@ class DiscreteBottleneck(nn.Module):
        tot_prob = (logprobs * softmax).sum()

        if reverse_grad:
-            tot_prob = ReverseGrad.apply(tot_prob)
+            tot_prob = reverse_gradient(tot_prob)
        return tot_prob

@@ -1814,28 +1822,27 @@ def _test_bidirectional_conformer():
    print("tokens = ", tokens)
    print("supervision = ", supervision)
    # memory: [T, N, C]
-    (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision)
+    (memory, pos_emb, key_padding_mask) = m(feats, supervision)

    # ctc_output: [N, T, C].
    ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)

+    (sampled, softmax, positive_embed_shifted, negative_embed_shifted) = m.sample_forward(memory)
+
    decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens,
                                        sos_id=1,
                                        eos_id=2)
    print("decoder logprob = ", decoder_logprob)

-    (T, N, E) = bn_memory.shape
-    bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
-
    reverse_decoder_logprob = m.reverse_decoder_forward(
-        bn_memory_shifted, key_padding_mask,
+        positive_embed_shifted, key_padding_mask,
        sampled, softmax, tokens,
        sos_id=1, eos_id=2, padding_id=0)

    print("reverse decoder logprob = ", reverse_decoder_logprob)

    self_prediction_logprob = m.self_prediction_forward(
-        bn_memory_shifted, key_padding_mask,
+        negative_embed_shifted, key_padding_mask,
        sampled, softmax)

    print("self prediction logprob = ", self_prediction_logprob)
@@ -1850,24 +1857,24 @@ def _test_discrete_bottleneck():
    tot_classes = 256
    num_groups = 8
    interp_prob = 1.0
-    straight_through_scale = 0.0 # will change
+    straight_through_scale = 1.0 # will change
    need_softmax = True

-    b = DiscreteBottleneck(dim, tot_classes, num_groups,
+    b = SampleAndPredict(dim, tot_classes, num_groups,
                           interp_prob, straight_through_scale).to(device)

-    self_predictor = nn.Linear(tot_classes, dim).to(device)
-
    from_feats_predictor = nn.Linear(dim, dim).to(device)

-    model = nn.ModuleList([b, self_predictor, from_feats_predictor])
+    from_negative_embed_predictor = nn.Linear(dim, dim).to(device)
+
+    model = nn.ModuleList([b, from_feats_predictor, from_negative_embed_predictor])
    model.train()

    optim = torch.optim.Adam(params=model.parameters(),
                             lr=3.0e-04)


-    scale = 0.3 # determines the feature correlation..should be between 0 and 1.
+    scale = 0.5 # determines the feature correlation..should be between 0 and 1.
    #https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2)..
    # scale corresponds to rho^2, rho being sqrt(scale).
    mutual_information = dim * -0.5 * math.log(1.0 - scale)
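The comment above uses the closed form for the mutual information of two correlated Gaussians; a quick check of the numbers (dim here is a hypothetical feature dimension chosen just for the example):

import math

scale = 0.5                    # plays the role of rho^2
dim = 512                      # hypothetical feature dimension for this example
mi_per_dim = -0.5 * math.log(1.0 - scale)      # nats per dimension
total_mi = dim * mi_per_dim
print(f"MI/dim = {mi_per_dim:.3f} nats, total = {total_mi:.1f} nats")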
@@ -1886,37 +1893,34 @@ def _test_discrete_bottleneck():

        #print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}")

-        bn_memory, sampled, softmax = b(feats)
+        sampled, softmax, positive_embed, negative_embed = b(feats)
+
+        E = dim
+        negative_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device),
+                                            negative_embed[:-1,:,:]), dim=0)
+        positive_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device),
+                                            positive_embed[:-1,:,:]), dim=0)

        # using feats2 instead of feats will limit the mutual information,
        # to the MI between feats and feats2, which we computed and printed
        # above as mutual_information.
        predictor = from_feats_predictor(feats2)

        prob = b.compute_prob(predictor, sampled, softmax)

        if True:
-            sampled_reversed = ReverseGrad.apply(sampled)
-            predictor_reversed = self_predictor(sampled_reversed)
+            predictor_shifted = from_negative_embed_predictor(negative_embed_shifted)

-            if True:
-                predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
-                                                        predictor_reversed[:-1,:,:]), dim=0)
-            else:
-                # skip shifting.. want to see the effect..
-                predictor_reversed_shifted = predictor_reversed
-
-            self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
+            self_prob = b.compute_prob(predictor_shifted, sampled, softmax,
                                       reverse_grad=True)
-            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
+            normalized_self_prob = (self_prob / (T * N))

-            normalized_prob = (prob / (T * N)).to('cpu').item()
+            normalized_prob = (prob / (T * N))

-            loss = -(prob - self_prob)
-
-            normalized_loss = loss / (T * N)
+            normalized_loss = -(normalized_prob - normalized_self_prob)

            if i % 200 == 0:
-                print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
+                print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob.to('cpu').item()} - {-normalized_self_prob.to('cpu').item()} = {normalized_loss.to('cpu').item()}")

            normalized_loss.backward()
