Some progress in testing..

Daniel Povey 2021-09-18 15:00:27 +08:00
parent a20d490332
commit 38081bc3e3


@@ -320,7 +320,8 @@ class BidirectionalConformer(nn.Module):
         if num_self_predictor_layers > 0:
             encoder_layer = SimpleCausalEncoderLayer(d_model,
                                                      dropout=dropout)
-            self.self_predictor_encoder = encoder_layer
+            self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
+                                                          for _ in range(num_self_predictor_layers)])

         self.discrete_bottleneck = DiscreteBottleneck(
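Note on the change above: stacking num_self_predictor_layers copies of a single prototype layer via copy.deepcopy is the same idiom torch.nn.TransformerEncoder uses internally (its _get_clones helper). A minimal sketch of the pattern, using nn.Linear as a stand-in for SimpleCausalEncoderLayer:

    import copy
    import torch
    import torch.nn as nn

    def stack_layers(prototype: nn.Module, num_layers: int) -> nn.Sequential:
        # Each copy gets its own parameters; reusing the same instance N times
        # would silently tie the weights across all "layers".
        return nn.Sequential(*[copy.deepcopy(prototype) for _ in range(num_layers)])

    layers = stack_layers(nn.Linear(8, 8), 3)   # stand-in for SimpleCausalEncoderLayer
    print(layers(torch.randn(2, 8)).shape)      # torch.Size([2, 8])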
@@ -478,11 +479,11 @@ class BidirectionalConformer(nn.Module):
     def self_prediction_forward(
             self,
-            memory_shifted: torch.Tensor,
+            bn_memory_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
-            reverse_gradient: bool = True) -> Tensor:
+            reverse_grad: bool = True) -> Tensor:
         """
         Returns the total log-prob of the labels sampled in the discrete
         bottleneck layer, as predicted using a relatively simple model that
@@ -490,11 +491,11 @@ class BidirectionalConformer(nn.Module):
         [Appears on the denominator of an expression for mutual information].
         Args:
-          memory_shifted:
-            It's the output of forward(), with shape [T, N, E], shifted
-            by one so that shifted_memory[t] == memory[t-1], as in:
-              (T, N, E) = memory.shape
-              memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+          bn_memory_shifted:
+            It's the bn_memory output of forward(), with shape [T, N, E], shifted
+            by one so that bn_memory_shifted[t] == bn_memory[t-1], as in:
+              (T, N, E) = bn_memory.shape
+              bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
           memory_key_padding_mask:
             The padding mask from the encoder, of shape [N, T], boolean, True
             for masked locations.
@@ -503,41 +504,41 @@ class BidirectionalConformer(nn.Module):
             as given to the constructor. This will be needed for the 'reverse'
             model.
           softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
-          reverse_gradient: will likely be true. If true, the gradient is reversed twice
+          reverse_grad: will likely be true. If true, the gradient is reversed twice
             in this computation, so that we train predictors with the correct
             gradient, i.e. to predict, not anti-predict (since the return value
             of this function will appear with positive, not negative, sign in the
             loss function, so will be minimized).
             The gradient w.r.t. the non-self inputs to this function, though (i.e.
-            memory_shifted, sampled, softmax) will not be reversed, though.
+            bn_memory_shifted, sampled, softmax) will not be reversed.
         Returns:
             A scalar tensor, the **sum** of label smoothing loss over utterances
             in the batch without any normalization.
         """
-        if reverse_gradient:
-            # Reversing gradient for memory_shifted puts the gradient back into
+        if reverse_grad:
+            # Reversing gradient for bn_memory_shifted puts the gradient back into
             # the correct sign; we reversed it in
             # self.discrete_bottleneck.compute_prob(), in order that
             # self.self_predictor_encoder can learn to predict.
             # (You have to read the code in reverse, to reason about
             # what happens to the gradients).
-            memory_shifted = ReverseGrad.apply(memory_shifted)
+            bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)
         # no mask is needed for self_predictor_encoder; its CNN
         # layer uses left-padding only, making it causal.
-        predictor = self.self_predictor_encoder(memory_shifted)
+        predictor = self.self_predictor_encoder(bn_memory_shifted)
         prob = self.discrete_bottleneck.compute_prob(predictor,
                                                      sampled, softmax,
                                                      memory_key_padding_mask,
-                                                     reverse_gradient=True)
+                                                     reverse_grad=True)
         return prob

     def reverse_decoder_forward(
             self,
-            memory_shifted: torch.Tensor,
+            bn_memory_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
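How the pieces fit together (my reading of the docstrings here and of the new test at the bottom of this diff, not text from the commit): the sampled bottleneck labels y_t are scored twice, once by a predictor that can see the input and once by the causal self-predictor above, which only sees y_{<t} through bn_memory_shifted. The difference of the two total log-probs is, in expectation, a mutual-information-style objective,

    I(labels; input)  ~=  E[ log p(y_t | input) ]  -  E[ log p(y_t | y_{<t}) ]

so the self-prediction term plays the role of the "denominator" and enters the loss with positive sign (compare loss = -(prob - self_prob) in _test_discrete_bottleneck below).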
@@ -552,11 +553,11 @@ class BidirectionalConformer(nn.Module):
           supervision word-sequence.
         Args:
-          memory_shifted:
-            It's the output of forward(), with shape [T, N, E], shifted
-            by one so that shifted_memory[t] == memory[t-1], as in:
-              (T, N, E) = memory.shape
-              memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+          bn_memory_shifted:
+            It's the bn_memory output of forward(), with shape [T, N, E], shifted
+            by one so that bn_memory_shifted[t] == bn_memory[t-1], as in:
+              (T, N, E) = bn_memory.shape
+              bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
           memory_key_padding_mask:
             The padding mask from the encoder.
           sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
@@ -582,7 +583,7 @@ class BidirectionalConformer(nn.Module):
         token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]
         tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
-                                     padding_value=padding_id).to(memory_shifted.device)
+                                     padding_value=padding_id).to(bn_memory_shifted.device)
         print("tokens_padded = ", tokens_padded)

         tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
@@ -600,12 +601,12 @@ class BidirectionalConformer(nn.Module):
             src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.

-        T = memory_shifted.shape[0]
+        T = bn_memory_shifted.shape[0]
         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
         hidden_predictor = self.reverse_decoder(
-            tgt=memory_shifted,
+            tgt=bn_memory_shifted,
             memory=token_memory,
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)
@@ -780,7 +781,7 @@ class DiscreteBottleneck(nn.Module):
         torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))

-    def forward(self, x: Tensor, need_softmax: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
         """
         Forward computation.
         Args:
@@ -813,7 +814,7 @@
         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
-        softmax = x.softmax().reshape(S, N, tot_classes) if need_softmax else None
+        softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None

         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
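The dim=3 fix above assumes that x has already been reshaped to (S, N, num_groups, classes_per_group) at this point, which the .reshape(S, N, tot_classes) on the same line implies; the softmax has to be taken over the within-group class axis. A small shape sketch (shapes assumed for illustration):

    import torch

    S, N, num_groups, classes_per_group = 5, 2, 4, 64
    tot_classes = num_groups * classes_per_group

    x = torch.randn(S, N, num_groups, classes_per_group)
    softmax = x.softmax(dim=3).reshape(S, N, tot_classes)

    # Each group's probabilities sum to 1 independently of the other groups:
    print(torch.allclose(x.softmax(dim=3).sum(dim=3),
                         torch.ones(S, N, num_groups)))   # True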
@@ -829,15 +830,15 @@
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
             prob_floor = self.min_prob_ratio / self.classes_per_group
-            self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
+            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost

         embedding = self.norm_out(self.linear2(x))
         return (embedding, sampled, softmax)

     def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
-                     padding_mask: Optional[Tensor],
-                     reverse_gradient: bool = False) -> Tensor:
+                     padding_mask: Optional[Tensor] = None,
+                     reverse_grad: bool = False) -> Tensor:
         """
         Compute the total probability of the sampled probabilities, given
         some kind of predictor x (which we assume should not have access
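The one-character fix above (< instead of >) matters: the offsets are meant to boost classes whose running average probability has fallen below the floor, not the already-popular ones. A small numeric illustration (values assumed; only the names come from the diff):

    import torch

    classes_per_group = 4
    min_prob_ratio = 0.5
    prob_boost = 0.1

    class_probs = torch.tensor([0.40, 0.30, 0.25, 0.05])   # running averages
    class_offsets = torch.zeros(4)

    prob_floor = min_prob_ratio / classes_per_group          # 0.125
    class_offsets += (class_probs < prob_floor) * prob_boost
    print(class_offsets)   # tensor([0.0000, 0.0000, 0.0000, 0.1000]) -- only the rare class is boosted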
@@ -846,7 +847,8 @@ class DiscreteBottleneck(nn.Module):
           x: The predictor tensor, of shape (S, N, E) where S is the
             sequence length, N is the batch size and E is the embedding dim
-            (`dim` arg to __init__())
+            (`dim` arg to __init__()). This is projected from `sampled`
+            with a learnable matrix.
           sampled: A tensor of shape (S, N, C) where C is the `tot_classes`
             arg to the constructor, containing the sampled probabilities.
           softmax: A tensor of shape (S, N, C), this is the "smooth" version
@@ -857,7 +859,7 @@ class DiscreteBottleneck(nn.Module):
             (batch_size, sequence_length), with True in masked positions
             that are to be ignored in the sum of probabilities.
-          reverse_gradient: If true, negate the gradient that is passed back
+          reverse_grad: If true, negate the gradient that is passed back
             to 'x' and to the modules self.pred_linear and pred_cross.
             This will be useful in computing a loss function that has
             a likelihood term with negative sign (i.e. the self-prediction).
@@ -867,7 +869,7 @@ class DiscreteBottleneck(nn.Module):
         Returns a scalar Tensor representing the total probability.
         """
-        if reverse_gradient:
+        if reverse_grad:
             sampled = ReverseGrad.apply(sampled)
             if softmax is None:
                 softmax = sampled
@@ -906,7 +908,7 @@ class DiscreteBottleneck(nn.Module):
         else:
             tot_prob = (logprobs * softmax).sum()

-        if reverse_gradient:
+        if reverse_grad:
             tot_prob = ReverseGrad.apply(tot_prob)
         return tot_prob
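ReverseGrad is used throughout this diff but not defined in it; the usual gradient-reversal pattern looks roughly like this (a sketch of the standard idiom, not necessarily the repo's exact implementation):

    import torch

    class ReverseGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Identity in the forward pass.
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            # Negate the gradient in the backward pass.
            return -grad_output

    x = torch.randn(3, requires_grad=True)
    ReverseGrad.apply(x).sum().backward()
    print(x.grad)   # tensor([-1., -1., -1.])

Reading the comments in self_prediction_forward together with compute_prob: the reversal of tot_prob flips the gradient for the predictor's own parameters so that they learn to predict, while the extra reversals applied to sampled and to bn_memory_shifted flip it back for the upstream bottleneck, which keeps the anti-predict direction implied by the positive sign of this term in the loss.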
@@ -1789,18 +1791,18 @@ def _test_bidirectional_conformer():
                                         eos_id=2)
     print("decoder logprob = ", decoder_logprob)

-    (T, N, E) = memory.shape
-    memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+    (T, N, E) = bn_memory.shape
+    bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)

     reverse_decoder_logprob = m.reverse_decoder_forward(
-        memory_shifted, key_padding_mask,
+        bn_memory_shifted, key_padding_mask,
         sampled, softmax, tokens,
         sos_id=1, eos_id=2, padding_id=0)
     print("reverse decoder logprob = ", reverse_decoder_logprob)

     self_prediction_logprob = m.self_prediction_forward(
-        memory_shifted, key_padding_mask,
+        bn_memory_shifted, key_padding_mask,
         sampled, softmax)
     print("self prediction logprob = ", self_prediction_logprob)
@@ -1809,5 +1811,67 @@ def _test_bidirectional_conformer():
     loss.backward()

+
+def _test_discrete_bottleneck():
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dim = 128
+    tot_classes = 256
+    num_groups = 4
+    interp_prob = 0.8
+    straight_through_scale = 0.0  # will change
+    need_softmax = True
+
+    b = DiscreteBottleneck(dim, tot_classes, num_groups,
+                           interp_prob, straight_through_scale).to(device)
+    self_predictor = nn.Linear(tot_classes, dim).to(device)
+
+    optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
+                            lr=1.0e-03, momentum=0.99)
+
+    for epoch in range(10):
+        state_dict = dict()
+        state_dict['b'] = b.state_dict()
+        state_dict['s'] = self_predictor.state_dict()
+        torch.save(state_dict, f'epoch-{epoch}.pt')
+
+        for i in range(1000):
+            # TODO: also test padding_mask
+            T = 300
+            N = 10
+            feats = torch.randn(T, N, dim, device=device)
+            bn_memory, sampled, softmax = b(feats)
+
+            predictor = feats  # Could also use `bn_memory`, perhaps. But using
+                               # `feats` as the predictor, because it contains everything,
+                               # will give us a bound on the max information.
+            prob = b.compute_prob(feats, sampled, softmax)
+
+            sampled_reversed = ReverseGrad.apply(sampled)
+            predictor_reversed = self_predictor(sampled_reversed)
+            predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                    predictor_reversed[:-1,:,:]), dim=0)
+            self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
+                                       reverse_grad=True)
+
+            normalized_prob = (prob / (T * N)).to('cpu').item()
+            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
+
+            loss = -(prob - self_prob)
+            normalized_loss = loss / (T * N)
+            if i % 200 == 0:
+                print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
+
+            normalized_loss.backward()
+            optim.step()
+            optim.zero_grad()
+

 if __name__ == '__main__':
     _test_bidirectional_conformer()
+    _test_discrete_bottleneck()
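The per-epoch checkpoints written by _test_discrete_bottleneck can be reloaded using the same 'b'/'s' keys and file names used above; a sketch, assuming it runs in the same module so that DiscreteBottleneck is in scope:

    import torch
    import torch.nn as nn

    # Same hyperparameters as in _test_discrete_bottleneck.
    dim, tot_classes, num_groups = 128, 256, 4
    interp_prob, straight_through_scale = 0.8, 0.0

    b = DiscreteBottleneck(dim, tot_classes, num_groups,
                           interp_prob, straight_through_scale)
    self_predictor = nn.Linear(tot_classes, dim)

    state_dict = torch.load('epoch-0.pt', map_location='cpu')
    b.load_state_dict(state_dict['b'])
    self_predictor.load_state_dict(state_dict['s'])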