diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py index f3cd9054b..3a667ce15 100644 --- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py @@ -71,7 +71,7 @@ class ConformerTrunk(nn.Module): def forward( self, x: torch.Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Args: x: @@ -149,8 +149,7 @@ class BidirectionalConformer(nn.Module): num_encoder_layers: Number of encoder layers in the "trunk" that encodes the acoustic features num_ctc_encoder_layers: Number of layers in the CTC encoder - that comes after the trunk (and possibly the discrete - bottleneck, if bypass_bottleneck == True. + that comes after the trunk. These are just conformer encoder layers. num_decoder_layers: Number of layers in the attention decoder; this goes from the trunk to the word-pieces or phones. @@ -168,17 +167,14 @@ class BidirectionalConformer(nn.Module): useful function is to prevent "trivial" solutions such as collapse of the distribution to a single symbol, or symbols that are highly correlated across time. - bypass_bottleneck: If true, bypass the discrete bottleneck - when predicting the CTC output and the decoder - that decodes the word-pieces or phones. dropout: Dropout probability cnn_module_kernel: Kernel size in forward conformer layers is_bpe: If false, we'll add one (for EOS) to the number of classes at the output of the decoder use_feat_batchnorm: If true, apply batchnorm to the input features. - discrete_bottleneck_tot_classes: Total number of classes + discretization_tot_classes: Total number of classes (across all groups) in the discrete bottleneck - discrete_bottleneck_num_groups: Number of groups of classes/symbols + discretization_num_groups: Number of groups of classes/symbols in the discrete bottleneck """ def __init__( @@ -195,18 +191,15 @@ class BidirectionalConformer(nn.Module): num_reverse_encoder_layers: int = 4, num_reverse_decoder_layers: int = 4, num_self_predictor_layers: int = 3, - bypass_bottleneck: bool = True, dropout: float = 0.1, cnn_module_kernel: int = 31, is_bpe: bool = False, use_feat_batchnorm: bool = True, - discrete_bottleneck_tot_classes: int = 512, - discrete_bottleneck_num_groups: int = 4 + discretization_tot_classes: int = 512, + discretization_num_groups: int = 4 ) -> None: super(BidirectionalConformer, self).__init__() - self.bypass_bottleneck = bypass_bottleneck - self.trunk = ConformerTrunk(num_features, subsampling_factor, d_model, nhead, dim_feedforward, num_trunk_encoder_layers, dropout, @@ -324,53 +317,64 @@ class BidirectionalConformer(nn.Module): for _ in range(num_self_predictor_layers)]) - self.discrete_bottleneck = DiscreteBottleneck( + self.sample_and_predict = SampleAndPredict( dim=d_model, - tot_classes=discrete_bottleneck_tot_classes, - num_groups=discrete_bottleneck_num_groups) + tot_classes=discretization_tot_classes, + num_groups=discretization_num_groups) def forward(self, x: Tensor, supervision: Optional[Supervisions] = None, - need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]: + need_softmax: bool = True) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """ - Forward function that "encodes" the features. 
+ Forward function that "encodes" the acoustic features through the "trunk" + (the shared part of the encoding of the encoding of the acoustic features) Args: x: The input tensor. Its shape is [N, T, F], i.e. [batch_size, num_frames, num_features]. supervision: - Supervision in lhotse format (optional) + Supervision in lhotse format (optional; needed only for acoustic length + information) See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa (CAUTION: It contains length information, i.e., start and number of frames, before subsampling). Used only to compute masking information. - need_softmax: - If true, the last output ("softmax") will be computed. This can be useful - in the reverse model, but only necessary if straight_through_scale != 1.0. - Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where: + Returns: (memory, pos_emb, key_padding_mask), where: memory: a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T is actually a subsampled form of the num_frames of the input `x`. If self.bypass_bottleneck, it will be taken before the discrete bottleneck; otherwise, from after. - bn_memory: The same shape as `memory`, but comes after the discrete bottleneck - regardless of the value of self.bypass_bottleneck. pos_emb: The relative positional embedding; will be given to ctc_encoder_forward() - sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes` - as given to the constructor. This will be needed for the 'reverse' model. - softmax: a "soft" version of `sampled`. Will only be returned if need_softmax == True; - else will be None. key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of shape [N, T] (only if supervision was supplied, else None). """ - encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision) + memory, pos_emb, memory_key_padding_mask = self.trunk(x, supervision) + return memory, pos_emb, memory_key_padding_mask - bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output) - memory = encoder_output if self.bypass_bottleneck else bn_memory + def sample_forward(self, memory: Tensor) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]: + """ + Given the "memory" from forward(), run the sample_and_redict module. + See documentation for forward() of class SampleAndPredict for more info. 
- return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask) + Returns (sampled, softmax, positive_embed_shifted, negative_embed_shifted), + where positive_embed_shifted, for instance, is positive_embed + shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in: + (T, N, E) = positive_embed.shape + positive_embed_shifted = torch.cat((torch.zeros(1, N, E), positive_embed[:-1,:,:]), dim=0) + + """ + (sampled, softmax, positive_embed, negative_embed) = self.sample_and_predict(memory) + + (T, N, E) = memory.shape + device = memory.device + zeros = torch.zeros(1, N, E).to(memory.device) + negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0) + positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0) + + return (sampled, softmax, positive_embed_shifted, negative_embed_shifted) def decoder_forward( self, @@ -479,11 +483,10 @@ class BidirectionalConformer(nn.Module): def self_prediction_forward( self, - bn_memory_shifted: torch.Tensor, + negative_embed_shifted: torch.Tensor, memory_key_padding_mask: torch.Tensor, sampled: torch.Tensor, - softmax: Optional[torch.Tensor], - reverse_grad: bool = True) -> Tensor: + softmax: Optional[torch.Tensor]) -> Tensor: """ Returns the total log-prob of the the labels sampled in the discrete bottleneck layer, as predicted using a relatively simple model that @@ -491,54 +494,38 @@ class BidirectionalConformer(nn.Module): [Appears on the denominator of an expression for mutual information]. Args: - bn_memory_shifted: - It's the bn_memory output of forward(), with shape [T, N, E], shifted - by one so that bn_shifted_memory[t] == bn_memory[t-1], as in: - (T, N, E) = bn_memory.shape - bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0) + negative_embed_shifted: + The negative_embed_shifted output of self.sample_forward(), with shape [T, N, E] memory_key_padding_mask: The padding mask from the encoder, of shape [N, T], boolean, True for masked locations. sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C] - where C corresponds to `discrete_bottleneck_tot_classes` + where C corresponds to `discretization_tot_classes` as given to the constructor. This will be needed for the 'reverse' model. softmax: is a "soft" version of `sampled`; if None, will default to `sampled`. - reverse_grad: will likely be true. If true, the gradient is reversed twice - in this computation, so that we train predictors with the correct - gradient, i.e. to predict, not anti-predict (since the return value - of this function will appear with positive, not negative, sign in the - loss function, so will be minimized). - The gradient w.r.t. the non-self inputs to this function, though (i.e. - bn_memory_shifted, sampled, softmax) will not be reversed, though. + Returns: - A scalar tensor, the **sum** of label smoothing loss over utterances + A scalar tensor, the **sum** of the log-prob loss over utterances in the batch without any normalization. """ - if reverse_grad: - # Reversing gradient for bn_memory_shifted puts the gradient back into - # the correct sign; we reversed it in - # self.discrete_bottleneck.compute_prob(), in order that - # self.self.predictor_encoder can learn to predict. - # (You have to read the code in reverse, to reason about - # what happens to the gradients). - bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted) - # no mask is needed for self_predictor_encoder; its CNN - # layer uses left-padding only, making it causal. 
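# A minimal, self-contained sketch of the one-frame shift performed in
# sample_forward() above: frame t of the shifted tensor holds frame t-1 of the
# original and frame 0 is zero-filled, so anything fed the shifted embeddings
# only ever sees the past.  (The helper name is illustrative, not from this file.)
import torch
from torch import Tensor

def shift_right_one(x: Tensor) -> Tensor:
    """x has shape [T, N, E]; returns a copy delayed by one frame."""
    (T, N, E) = x.shape
    zeros = torch.zeros(1, N, E, device=x.device, dtype=x.dtype)
    return torch.cat((zeros, x[:-1, :, :]), dim=0)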
- predictor = self.self_predictor_encoder(bn_memory_shifted) + # layer uses left-padding only, making it causal, so the mask + # is redundant (it wouldn't affect any of the + # outputs we care about). + predictor = self.self_predictor_encoder(negative_embed_shifted) - prob = self.discrete_bottleneck.compute_prob(predictor, - sampled, softmax, - memory_key_padding_mask, - reverse_grad=True) + prob = self.sample_and_predict.compute_prob(predictor, + sampled, softmax, + memory_key_padding_mask, + reverse_grad=True) return prob def reverse_decoder_forward( self, - bn_memory_shifted: torch.Tensor, + positive_embed_shifted: torch.Tensor, memory_key_padding_mask: torch.Tensor, sampled: torch.Tensor, softmax: Optional[torch.Tensor], @@ -553,15 +540,14 @@ class BidirectionalConformer(nn.Module): supervision word-sequence. Args: - bn_memory_shifted: - It's the bn_memory output of forward(), with shape [T, N, E], shifted - by one so that shifted_memory[t] == bn_memory[t-1], as in: - (T, N, E) = memory.shape - bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0) + positive_embed_shifted: + It's the positive_embed_shifted output of self.sample_forward(), with + shape [T, N, E] memory_key_padding_mask: The padding mask from the encoder. - sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes` - as given to the constructor. This will be needed for the 'reverse' model. + sampled: is a Tensor of shape [T, N, C] where C corresponds to + `discretization_tot_classes` as given to the constructor. + This will be needed for the 'reverse' model. softmax: is a "soft" version of `sampled`; if None, will default to `sampled`. token_ids: A list-of-list IDs. Each sublist contains IDs for an utterance. @@ -583,7 +569,7 @@ class BidirectionalConformer(nn.Module): token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ] tokens_padded = pad_sequence(token_ids_tensors, batch_first=True, - padding_value=padding_id).to(bn_memory_shifted.device) + padding_value=padding_id).to(positive_embed_shifted.device) print("tokens_padded = ", tokens_padded) tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id) @@ -601,17 +587,17 @@ class BidirectionalConformer(nn.Module): src_key_padding_mask=tokens_key_padding_mask) # token_memory is of shape (S, N, C), if S is length of token sequence. 
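# A small runnable sketch of the token preparation done in reverse_decoder_forward()
# above; the IDs below are made up, and the real code builds the mask with
# decoder_padding_mask() rather than a direct comparison.
import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = [[5, 7, 9], [4, 6]]
sos_id, eos_id, padding_id = 1, 2, 0
seqs = [torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids]
tokens_padded = pad_sequence(seqs, batch_first=True, padding_value=padding_id)
# tensor([[1, 5, 7, 9, 2],
#         [1, 4, 6, 2, 0]])
tokens_key_padding_mask = tokens_padded == padding_id  # True at padded positions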
- T = bn_memory_shifted.shape[0] + T = positive_embed_shifted.shape[0] # the targets, here, are the hidden discrete symbols we are predicting - tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device) + tgt_mask = generate_square_subsequent_mask(T, device=positive_embed_shifted.device) hidden_predictor = self.reverse_decoder( - tgt=bn_memory_shifted, + tgt=positive_embed_shifted, memory=token_memory, tgt_mask=tgt_mask, memory_key_padding_mask=tokens_key_padding_mask) - total_prob = self.discrete_bottleneck.compute_prob( + total_prob = self.sample_and_predict.compute_prob( hidden_predictor, sampled, softmax, @@ -676,6 +662,9 @@ class ReverseGrad(torch.autograd.Function): def backward(ctx, x_grad): return -x_grad +def reverse_gradient(x: Tensor) -> Tensor: + return ReverseGrad.apply(x) + class DebugGrad(torch.autograd.Function): @staticmethod @@ -693,19 +682,21 @@ class DebugGrad(torch.autograd.Function): -class DiscreteBottleneck(nn.Module): - """ - This layer forces its input through an information bottleneck via - a discretization operation with sampling, and allows you to - predict the likelihood of those discretized values. - We use the torch-flow-sampling - package for this, to provide a differentiable softmax that should be - much better than Gumbel in terms of actually giving an information - bottleneck. +class SampleAndPredict(nn.Module): + """This module discretizes its input and lets you predict the + discrete classes. + We use the torch-flow-sampling package for this, to provide a differentiable + softmax that should be much better than Gumbel in terms of actually giving + an information bottleneck. + (However, if straight_through_scale == 1.0, which actually seems + to be working fine so far, it's the same as just sampling from the + categorical distribution and using straight-through derivatives + (i.e. the derivatives are as if that output had just been a softmax). + This may depend somewhat on the model; straight_through_scale == 0.0 + is definitely safer from a correctness point of view.) Args: - dim: The input and output dimension of the discrete bottleneck - operation. + dim: The input feature dimension tot_classes: The total number of classes (across all groups of classes); each group is separately discretized num_groups: The number of groups of classes; discretization @@ -728,10 +719,10 @@ class DiscreteBottleneck(nn.Module): 1.0 - straight_through_scale. min_prob_ratio: For any class whose average softmax output, for a given minibatch, is less than - min_prob_ratio times - include_predictor: If true, include the parameters - necessary to predict the likelihoods of the - classes from some kind of input embedding. + min_prob_ratio times the average probability, + boost its probability; this is a mechanism + to avoid "losing" classes; we hope it won't really + be necessary in practice.
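# A quick standalone check of what reverse_gradient() above does: the forward
# value is unchanged, but the gradient flowing back through it is negated.
# (This re-implements the same autograd Function rather than importing it.)
import torch

class _ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, x_grad):
        return -x_grad

x = torch.tensor([1.0, 2.0], requires_grad=True)
_ReverseGrad.apply(x).sum().backward()
print(x.grad)  # tensor([-1., -1.]): same value forward, sign-flipped gradient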
""" def __init__( self, @@ -739,12 +730,11 @@ class DiscreteBottleneck(nn.Module): tot_classes: int, num_groups: int, interp_prob: float = 1.0, - straight_through_scale: float = 0.333, + straight_through_scale: float = 0.0, min_prob_ratio: float = 0.1, - include_predictor: bool = True ): - super(DiscreteBottleneck, self).__init__() - self.norm_in = nn.LayerNorm(dim) + super(SampleAndPredict, self).__init__() + self.linear1 = nn.Linear(dim, tot_classes) self.num_groups = num_groups @@ -753,44 +743,49 @@ class DiscreteBottleneck(nn.Module): self.min_prob_ratio = min_prob_ratio self.tot_classes = tot_classes self.classes_per_group = tot_classes // num_groups + + # prob_boost relates to the min_prob_ratio setting. It's not configurable for now. self.prob_boost = 1.0e-05 # class_probs is a rolling mean of the output of the sampling operation. # When any element of it gets below self.min_prob_ratio / self.classes_per_group, # we boost the class's probability by adding self.prob_boost to # that element of self.class_offset - self.class_probs_decay = 0.9 + self.class_probs_decay = 0.95 self.register_buffer('class_probs', torch.ones(tot_classes) / self.classes_per_group) # class_offsets is a bias term that we add to logits before the sampling # operation in order to enforce that no class is too infrequent # (c.f. 'min_prob_ratio'). self.register_buffer('class_offsets', torch.zeros(tot_classes)) - self.linear2 = nn.Linear(tot_classes, dim, bias=False) - self.norm_out = nn.LayerNorm(dim) - if include_predictor: - # pred_linear predicts the class probabilities from a predictor - # embedding. - self.pred_linear = nn.Linear(dim, tot_classes) + # pred_linear predicts the class probabilities from a predictor + # embedding of dimension 'dim' supplied by the user. + self.pred_linear = nn.Linear(dim, tot_classes) + + if self.num_groups > 1: + # We predict the logprobs of each group from the outputs of the + # previous groups. This is done via a masked multiply, where + # the masking operates on blocks. This projects from [all but + # the last group] to [all but the first group], so the diagonal + # of the mask can be 1, not 0, saving compute.. + d = tot_classes - self.classes_per_group + c = self.classes_per_group + self.pred_cross = nn.Parameter(torch.zeros(d, d)) + # If d == 4 and c == 2, the expression below has the following value + # (treat True as 1 and False as 0). + #tensor([[ True, True, False, False], + # [ True, True, False, False], + # [ True, True, True, True], + # [ True, True, True, True]]) + self.register_buffer('pred_cross_mask', + ((torch.arange(d) // c).unsqueeze(1) >= + (torch.arange(d) // c).unsqueeze(0))) + + # linear2 and post_layer_norm come after the sampling. + self.linear2 = nn.Linear(tot_classes, dim) + self.post_layer_norm = nn.LayerNorm(dim) - if self.num_groups > 1: - # We predict the logprobs of each group from the outputs of the - # previous groups. This is done via a masked multiply, where - # the masking operates on blocks. This projects from [all but - # the last group] to [all but the first group], so the diagonal - # of the mask can be 1, not 0, saving compute.. - d = tot_classes - self.classes_per_group - c = self.classes_per_group - self.pred_cross = nn.Parameter(torch.zeros(d, d)) - # If d == 4 and c == 2, the expression below has the following value - # (treat True as 1 and False as 0). 
- #tensor([[ True, True, False, False], - # [ True, True, False, False], - # [ True, True, True, True], - # [ True, True, True, True]]) - self.register_buffer('pred_cross_mask', - ((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0))) self._reset_parameters() def _reset_parameters(self): @@ -798,17 +793,14 @@ class DiscreteBottleneck(nn.Module): torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5)) - def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]: """ - Forward computation. + Forward computation. See also compute_prob(). Args: x: The input tensor, of shape (S, N, E) where S is the sequence length, N is the batch size and E is the embedding dim. - Returns (embeddding, sampled, softmax), where: + Returns (sampled, softmax, positive_embed, negative_embed), where: - embedding: of shape (S, N, E) where E is the embedding dimension (`dim` arg - to the constructor), this is the output embedding; it is projected from the - sampled class probabilities. sampled: of shape (S, N, C) where C is the `tot_classes` to the constructor, these are the sampled one-hot vectors or interpolations thereof. They will be needed if we try to predict the discrete values @@ -820,19 +812,31 @@ class DiscreteBottleneck(nn.Module): expectation of the result of sampling -> lower-variance derivatives. This is unnecessary if straight_through_scale == 1.0, since in that case it would not affect the backpropagated derivatives. + positive_embed: The samples projected back down to the embedding + dimension, and layer-normed (`dim` passed to the constructor). + negative_embed: This is numerically the same value as positive_embed, + but has its gradients reversed prior to the projection and + LayerNorm. This is intended to be used for terms that appear + with opposite sign in the loss function, to be fed to + something whose gradient is already (going to be) reversed: + specifically, the self-prediction network. """ - x = self.norm_in(x) * 5 # * 5 gives lower entropy.. starts training faster.. - x = self.linear1(x) - x = x + self.class_offsets + x = self.linear1(x * 5) # multiplying 5 gives lower entropy, makes it + # begin training faster.. + + if self.min_prob_ratio > 0.0: + x = x + self.class_offsets (S, N, tot_classes) = x.shape x = x.reshape(S, N, self.num_groups, self.classes_per_group) - # This is a little wasteful since we already compute the softmax - # inside 'flow_sample'. + # This is a little wasteful since we already compute the softmax inside + # 'flow_sample'. softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None + if random.random() < 0.001: + # Some info that's useful for debug. 
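# Sketch of the per-group softmax used in forward() above: the logits are viewed
# as num_groups independent distributions over classes_per_group symbols each,
# which is also how the entropy printed in the debug block below is computed.
# (Shapes here are illustrative.)
import torch

S, N, num_groups, classes_per_group = 3, 2, 4, 8
x = torch.randn(S, N, num_groups * classes_per_group)
x = x.reshape(S, N, num_groups, classes_per_group)
softmax = x.softmax(dim=3)   # one distribution per group, summing to 1 over dim 3
entropy = -(softmax * (softmax + 1.0e-20).log()).sum(-1).mean()
# entropy is at most log(classes_per_group) nats, i.e. log(8) here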
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group) logsoftmax_temp = (softmax_temp + 1.0e-20).log() negentropy = (softmax_temp * logsoftmax_temp).sum(-1).mean() @@ -841,8 +845,9 @@ class DiscreteBottleneck(nn.Module): global_log_softmax = (global_softmax + 1.0e-20).log() global_negentropy = (global_softmax * global_log_softmax).sum(-1).mean() - print("Entropy = ", -negentropy.to('cpu').item(), ", averaged entropy = ", -global_negentropy.to('cpu').item()) - + print("SampleAndPredict: entropy = ", + -negentropy.to('cpu').item(), ", averaged entropy = ", + -global_negentropy.to('cpu').item()) x = torch_flow_sampling.flow_sample(x, @@ -854,19 +859,22 @@ class DiscreteBottleneck(nn.Module): sampled = x - if self.training and False: + if self.training and self.min_prob_ratio > 0.0: mean_class_probs = torch.mean(x.detach(), dim=(0,1)) self.class_probs = (self.class_probs * self.class_probs_decay + mean_class_probs * (1.0 - self.class_probs_decay)) prob_floor = self.min_prob_ratio / self.classes_per_group self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost - embedding = self.norm_out(self.linear2(x)) - #if random.random() < 0.01: - # return (embedding, DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax")) - #else: - return (embedding, sampled, softmax) + positive_embed = self.post_layer_norm(self.linear2(sampled)) + negative_embed = self.post_layer_norm(self.linear2(reverse_gradient(sampled))) + + if random.random() < 0.002: + return (DebugGrad.apply(sampled, "sampled"), DebugGrad.apply(softmax, "softmax"), + positive_embed, negative_embed) + else: + return (sampled, softmax, positive_embed, negative_embed) def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor], @@ -891,23 +899,23 @@ class DiscreteBottleneck(nn.Module): padding_mask: Optionally, a boolean tensor of shape (N, S), i.e. (batch_size, sequence_length), with True in masked positions that are to be ignored in the sum of probabilities. - reverse_grad: If true, negate the gradient that is passed back to 'x' and to the modules self.pred_linear and pred_cross. This will be useful in computing a loss function that has a likelihood term with negative sign (i.e. the self-prediction). - We'll later need negate the gradient one more more time - where we give the input to the prediction module that - generated 'x'. + We'll later need to negate this gradient one more time + (it's expected, when reverse_grad == True, that x would + derive somehow from `negative_embed`, so that the gradient + will eventually go back to the correct sign.) Returns a scalar Tensor representing the total probability. """ if reverse_grad: - sampled = ReverseGrad.apply(sampled) + sampled = reverse_gradient(sampled) if softmax is None: softmax = sampled elif reverse_grad: - softmax = ReverseGrad.apply(softmax) + softmax = reverse_gradient(softmax) logprobs = self.pred_linear(x) @@ -942,7 +950,7 @@ class DiscreteBottleneck(nn.Module): tot_prob = (logprobs * softmax).sum() if reverse_grad: - tot_prob = ReverseGrad.apply(tot_prob) + tot_prob = reverse_gradient(tot_prob) return tot_prob @@ -1814,28 +1822,27 @@ def _test_bidirectional_conformer(): print("tokens = ", tokens) print("supervision = ", supervision) # memory: [T, N, C] - (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision) + (memory, pos_emb, key_padding_mask) = m(feats, supervision) # ctc_output: [N, T, C].
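# Sketch of the min_prob_ratio mechanism used in SampleAndPredict.forward()
# above (illustrative numbers): a rolling mean of per-class sample probabilities
# is kept, and any class whose rolling mean drops below
# min_prob_ratio / classes_per_group gets a small additive boost to its logit
# offset on every such minibatch, so that no class can silently die out.
import torch

tot_classes, classes_per_group = 8, 4
min_prob_ratio, prob_boost, decay = 0.1, 1.0e-05, 0.95
class_probs = torch.ones(tot_classes) / classes_per_group
class_offsets = torch.zeros(tot_classes)

mean_class_probs = torch.rand(tot_classes) / classes_per_group  # from one minibatch
class_probs = class_probs * decay + mean_class_probs * (1.0 - decay)
prob_floor = min_prob_ratio / classes_per_group
class_offsets += (class_probs < prob_floor) * prob_boost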
ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask) + (sampled, softmax, positive_embed_shifted, negative_embed_shifted) = m.sample_forward(memory) + decoder_logprob = m.decoder_forward(memory, key_padding_mask, tokens, sos_id=1, eos_id=2) print("decoder logprob = ", decoder_logprob) - (T, N, E) = bn_memory.shape - bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0) - reverse_decoder_logprob = m.reverse_decoder_forward( - bn_memory_shifted, key_padding_mask, + positive_embed_shifted, key_padding_mask, sampled, softmax, tokens, sos_id=1, eos_id=2, padding_id=0) print("reverse decoder logprob = ", reverse_decoder_logprob) self_prediction_logprob = m.self_prediction_forward( - bn_memory_shifted, key_padding_mask, + negative_embed_shifted, key_padding_mask, sampled, softmax) print("self prediction logprob = ", self_prediction_logprob) @@ -1850,24 +1857,24 @@ def _test_discrete_bottleneck(): tot_classes = 256 num_groups = 8 interp_prob = 1.0 - straight_through_scale = 0.0 # will change + straight_through_scale = 1.0 # will change need_softmax = True - b = DiscreteBottleneck(dim, tot_classes, num_groups, - interp_prob, straight_through_scale).to(device) - - self_predictor = nn.Linear(tot_classes, dim).to(device) + b = SampleAndPredict(dim, tot_classes, num_groups, + interp_prob, straight_through_scale).to(device) from_feats_predictor = nn.Linear(dim, dim).to(device) - model = nn.ModuleList([b, self_predictor, from_feats_predictor]) + from_negative_embed_predictor = nn.Linear(dim, dim).to(device) + + model = nn.ModuleList([b, from_feats_predictor, from_negative_embed_predictor]) model.train() optim = torch.optim.Adam(params=model.parameters(), lr=3.0e-04) - scale = 0.3 # determines the feature correlation..should be between 0 and 1. + scale = 0.5 # determines the feature correlation..should be between 0 and 1. #https://en.wikipedia.org/wiki/Mutual_information#Linear_correlation, -0.5 log(1 - rho^2).. # scale corresponds to rho^2, rho being sqrt(scale). mutual_information = dim * -0.5 * math.log(1.0 - scale) @@ -1886,37 +1893,34 @@ def _test_discrete_bottleneck(): #print(f"norm(feats) ={feats.norm()} vs. norm(feats2) = {feats2.norm()}") - bn_memory, sampled, softmax = b(feats) + sampled, softmax, positive_embed, negative_embed = b(feats) + + E = dim + negative_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device), + negative_embed[:-1,:,:]), dim=0) + positive_embed_shifted = torch.cat((torch.zeros(1, N, E).to(device), + positive_embed[:-1,:,:]), dim=0) # using feats2 instead of feats will limit the mutual information, # to the MI between feats and feats2, which we computed and printed # above as mutual_information. predictor = from_feats_predictor(feats2) + prob = b.compute_prob(predictor, sampled, softmax) if True: - sampled_reversed = ReverseGrad.apply(sampled) - predictor_reversed = self_predictor(sampled_reversed) + predictor_shifted = from_negative_embed_predictor(negative_embed_shifted) - if True: - predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device), - predictor_reversed[:-1,:,:]), dim=0) - else: - # skip shifting.. want to see the effect.. 
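# Worked instance of the mutual-information ceiling computed in the test above:
# for jointly Gaussian features with per-dimension correlation rho,
# I = -0.5 * log(1 - rho^2) nats per dimension, and the test uses scale = rho^2.
# (dim = 256 below is just an example value; the test sets its own dim.)
import math

dim, scale = 256, 0.5
mi_per_dim = -0.5 * math.log(1.0 - scale)   # about 0.347 nats per dimension
mutual_information = dim * mi_per_dim       # about 88.7 nats per frame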
- predictor_reversed_shifted = predictor_reversed - - self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax, + self_prob = b.compute_prob(predictor_shifted, sampled, softmax, reverse_grad=True) - normalized_self_prob = (self_prob / (T * N)).to('cpu').item() + normalized_self_prob = (self_prob / (T * N)) - normalized_prob = (prob / (T * N)).to('cpu').item() + normalized_prob = (prob / (T * N)) - loss = -(prob - self_prob) - - normalized_loss = loss / (T * N) + normalized_loss = -(normalized_prob - normalized_self_prob) if i % 200 == 0: - print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}") + print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob.to('cpu').item()} - {-normalized_self_prob.to('cpu').item()} = {normalized_loss.to('cpu').item()}") normalized_loss.backward()
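# For reference, the loss minimized in the test above is, per frame,
#     loss = -( log p(sampled | feats2) - log p(sampled | shifted samples) )
# i.e. the negative of a mutual-information-style difference, so -loss should
# climb toward the `mutual_information` ceiling printed earlier as training
# proceeds.  A schematic of the rest of the training step (names as in the test;
# the optimizer calls are the usual PyTorch ones, not shown in this excerpt):
#
#     normalized_loss.backward()
#     optim.step()
#     optim.zero_grad()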