Some progress in testing..

Daniel Povey 2021-09-18 15:00:27 +08:00
parent a20d490332
commit 38081bc3e3


@@ -320,7 +320,8 @@ class BidirectionalConformer(nn.Module):
if num_self_predictor_layers > 0:
encoder_layer = SimpleCausalEncoderLayer(d_model,
dropout=dropout)
self.self_predictor_encoder = encoder_layer
self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
for _ in range(num_self_predictor_layers)])
self.discrete_bottleneck = DiscreteBottleneck(
@@ -478,11 +479,11 @@ class BidirectionalConformer(nn.Module):
def self_prediction_forward(
self,
memory_shifted: torch.Tensor,
bn_memory_shifted: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sampled: torch.Tensor,
softmax: Optional[torch.Tensor],
reverse_gradient: bool = True) -> Tensor:
reverse_grad: bool = True) -> Tensor:
"""
Returns the total log-prob of the labels sampled in the discrete
bottleneck layer, as predicted using a relatively simple model that
@@ -490,11 +491,11 @@ class BidirectionalConformer(nn.Module):
[Appears on the denominator of an expression for mutual information].
Args:
memory_shifted:
It's the output of forward(), with shape [T, N, E], shifted
by one so that shifted_memory[t] == memory[t-1], as in:
(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
bn_memory_shifted:
It's the bn_memory output of forward(), with shape [T, N, E], shifted
by one so that bn_memory_shifted[t] == bn_memory[t-1], as in:
(T, N, E) = bn_memory.shape
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
memory_key_padding_mask:
The padding mask from the encoder, of shape [N, T], boolean, True
for masked locations.
@@ -503,41 +504,41 @@ class BidirectionalConformer(nn.Module):
as given to the constructor. This will be needed for the 'reverse'
model.
softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
reverse_gradient: will likely be true. If true, the gradient is reversed twice
reverse_grad: will likely be true. If true, the gradient is reversed twice
in this computation, so that we train predictors with the correct
gradient, i.e. to predict, not anti-predict (since the return value
of this function will appear with positive, not negative, sign in the
loss function, so will be minimized).
The gradient w.r.t. the non-self inputs to this function, though (i.e.
memory_shifted, sampled, softmax) will not be reversed, though.
bn_memory_shifted, sampled, softmax) will not be reversed.
Returns:
A scalar tensor, the **sum** of the log-probs over utterances
in the batch, without any normalization.
"""
if reverse_gradient:
# Reversing gradient for memory_shifted puts the gradient back into
if reverse_grad:
# Reversing gradient for bn_memory_shifted puts the gradient back into
# the correct sign; we reversed it in
# self.discrete_bottleneck.compute_prob(), in order that
# self.self_predictor_encoder can learn to predict.
# (You have to read the code in reverse, to reason about
# what happens to the gradients).
memory_shifted = ReverseGrad.apply(memory_shifted)
bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)
# no mask is needed for self_predictor_encoder; its CNN
# layer uses left-padding only, making it causal.
predictor = self.self_predictor_encoder(memory_shifted)
predictor = self.self_predictor_encoder(bn_memory_shifted)
prob = self.discrete_bottleneck.compute_prob(predictor,
sampled, softmax,
memory_key_padding_mask,
reverse_gradient=True)
reverse_grad=True)
return prob
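ReverseGrad is used above but its definition lies outside this diff. As a minimal sketch (an assumption following the standard gradient-reversal-layer pattern, not necessarily the file's exact definition), it is an identity in the forward pass that flips the sign of the gradient in the backward pass:

import torch

class ReverseGrad(torch.autograd.Function):
    # Identity going forward; negate the incoming gradient going backward.
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return -grad_output

With such a definition, the double reversal described in the comments above (once on bn_memory_shifted here, once inside compute_prob) restores the original gradient sign for the bottleneck output while still letting self_predictor_encoder learn to predict.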
def reverse_decoder_forward(
self,
memory_shifted: torch.Tensor,
bn_memory_shifted: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sampled: torch.Tensor,
softmax: Optional[torch.Tensor],
@@ -552,11 +553,11 @@ class BidirectionalConformer(nn.Module):
supervision word-sequence.
Args:
memory_shifted:
It's the output of forward(), with shape [T, N, E], shifted
by one so that shifted_memory[t] == memory[t-1], as in:
bn_memory_shifted:
It's the bn_memory output of forward(), with shape [T, N, E], shifted
by one so that bn_memory_shifted[t] == bn_memory[t-1], as in:
(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
memory_key_padding_mask:
The padding mask from the encoder.
sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
@@ -582,7 +583,7 @@ class BidirectionalConformer(nn.Module):
token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]
tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
padding_value=padding_id).to(memory_shifted.device)
padding_value=padding_id).to(bn_memory_shifted.device)
print("tokens_padded = ", tokens_padded)
tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
@@ -600,12 +601,12 @@ class BidirectionalConformer(nn.Module):
src_key_padding_mask=tokens_key_padding_mask)
# token_memory is of shape (S, N, C), if S is length of token sequence.
T = memory_shifted.shape[0]
T = bn_memory_shifted.shape[0]
# the targets, here, are the hidden discrete symbols we are predicting
tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
hidden_predictor = self.reverse_decoder(
tgt=memory_shifted,
tgt=bn_memory_shifted,
memory=token_memory,
tgt_mask=tgt_mask,
memory_key_padding_mask=tokens_key_padding_mask)
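The tgt_mask used above comes from generate_square_subsequent_mask, which is not part of this diff. A plausible minimal sketch (an assumption based on the usual additive causal-mask convention; the repo's helper may differ in detail):

import torch

def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    # Additive attention mask of shape (sz, sz): 0.0 where position t may attend
    # (positions <= t), -inf strictly above the diagonal (future positions).
    mask = torch.full((sz, sz), float('-inf'), device=device)
    return torch.triu(mask, diagonal=1)

Assuming reverse_decoder follows the standard nn.TransformerDecoder convention, this makes it causal over the bottleneck sequence, so position t only sees bn_memory_shifted[0..t], i.e. bn_memory up to t-1.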
@@ -780,7 +781,7 @@ class DiscreteBottleneck(nn.Module):
torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
def forward(self, x: Tensor, need_softmax: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
"""
Forward computation.
Args:
@@ -813,7 +814,7 @@ class DiscreteBottleneck(nn.Module):
# This is a little wasteful since we already compute the softmax
# inside 'flow_sample'.
softmax = x.softmax().reshape(S, N, tot_classes) if need_softmax else None
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
x = torch_flow_sampling.flow_sample(x,
interp_prob=self.interp_prob,
@@ -829,15 +830,15 @@ class DiscreteBottleneck(nn.Module):
self.class_probs = (self.class_probs * self.class_probs_decay +
mean_class_probs * (1.0 - self.class_probs_decay))
prob_floor = self.min_prob_ratio / self.classes_per_group
self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
embedding = self.norm_out(self.linear2(x))
return (embedding, sampled, softmax)
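The corrected comparison above (`<` rather than `>`) boosts the offsets of classes whose running probability has fallen below the floor. A toy illustration with made-up numbers (not the module's real state):

import torch

class_probs = torch.tensor([0.50, 0.30, 0.15, 0.05])   # running per-class probabilities
min_prob_ratio, classes_per_group, prob_boost = 0.5, 4, 0.01
prob_floor = min_prob_ratio / classes_per_group          # 0.125
class_offsets = torch.zeros(4)
class_offsets += (class_probs < prob_floor) * prob_boost
print(class_offsets)   # tensor([0.0000, 0.0000, 0.0000, 0.0100]): only the rare class gets boosted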
def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
padding_mask: Optional[Tensor],
reverse_gradient: bool = False) -> Tensor:
padding_mask: Optional[Tensor] = None,
reverse_grad: bool = False) -> Tensor:
"""
Compute the total probability of the sampled probabilities, given
some kind of predictor x (which we assume should not have access
@@ -846,7 +847,8 @@ class DiscreteBottleneck(nn.Module):
x: The predictor tensor, of shape (S, N, E) where S is the
sequence length, N is the batch size and E is the embedding dim
(`dim` arg to __init__())
(`dim` arg to __init__()). This is projected from `sampled`
with a learnable matrix.
sampled: A tensor of shape (S, N, C) where C is the `tot_classes`
arg to the constructor, containing the sampled probabilities.
softmax: A tensor of shape (S, N, C), this is the "smooth" version
@@ -857,7 +859,7 @@ class DiscreteBottleneck(nn.Module):
(batch_size, sequence_length), with True in masked positions
that are to be ignored in the sum of probabilities.
reverse_gradient: If true, negate the gradient that is passed back
reverse_grad: If true, negate the gradient that is passed back
to 'x' and to the modules self.pred_linear and pred_cross.
This will be useful in computing a loss function that has
a likelihood term with negative sign (i.e. the self-prediction).
@@ -867,7 +869,7 @@ class DiscreteBottleneck(nn.Module):
Returns a scalar Tensor representing the total probability.
"""
if reverse_gradient:
if reverse_grad:
sampled = ReverseGrad.apply(sampled)
if softmax is None:
softmax = sampled
@@ -906,7 +908,7 @@ class DiscreteBottleneck(nn.Module):
else:
tot_prob = (logprobs * softmax).sum()
if reverse_gradient:
if reverse_grad:
tot_prob = ReverseGrad.apply(tot_prob)
return tot_prob
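To see why reversing both the sampled input and the returned total trains the predictor to predict, while tensors outside keep the ordinary gradient sign, here is a tiny standalone check with toy values (ReverseGrad redefined as in the earlier sketch, which is itself an assumption):

import torch

class ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

x = torch.tensor(2.0, requires_grad=True)  # plays the role of an outside input (reversed on the way in)
w = torch.tensor(3.0, requires_grad=True)  # plays the role of a predictor parameter
y = ReverseGrad.apply(w * ReverseGrad.apply(x))  # reversed again on the way out
y.backward()      # as if y entered the loss with a positive sign
print(w.grad)     # tensor(-2.): descent increases w, i.e. the predictor learns to make y larger
print(x.grad)     # tensor(3.): descent decreases x's contribution, the ordinary (un-reversed) direction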
@@ -1789,18 +1791,18 @@ def _test_bidirectional_conformer():
eos_id=2)
print("decoder logprob = ", decoder_logprob)
(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
(T, N, E) = bn_memory.shape
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
reverse_decoder_logprob = m.reverse_decoder_forward(
memory_shifted, key_padding_mask,
bn_memory_shifted, key_padding_mask,
sampled, softmax, tokens,
sos_id=1, eos_id=2, padding_id=0)
print("reverse decoder logprob = ", reverse_decoder_logprob)
self_prediction_logprob = m.self_prediction_forward(
memory_shifted, key_padding_mask,
bn_memory_shifted, key_padding_mask,
sampled, softmax)
print("self prediction logprob = ", self_prediction_logprob)
@@ -1809,5 +1811,67 @@ def _test_bidirectional_conformer():
loss.backward()
def _test_discrete_bottleneck():
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dim = 128
tot_classes = 256
num_groups = 4
interp_prob = 0.8
straight_through_scale = 0.0 # will change
need_softmax = True
b = DiscreteBottleneck(dim, tot_classes, num_groups,
interp_prob, straight_through_scale).to(device)
self_predictor = nn.Linear(tot_classes, dim).to(device)
optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
lr=1.0e-03, momentum=0.99)
for epoch in range(10):
state_dict = dict()
state_dict['b'] = b.state_dict()
state_dict['s'] = self_predictor.state_dict()
torch.save(state_dict, f'epoch-{epoch}.pt')
for i in range(1000):
# TODO: also test padding_mask
T = 300
N = 10
feats = torch.randn(T, N, dim, device=device)
bn_memory, sampled, softmax = b(feats)
predictor = feats  # Could also use `bn_memory`, perhaps; using the raw feats as the
# predictor because it contains everything, which gives
# us a bound on the max information.
prob = b.compute_prob(predictor, sampled, softmax)
sampled_reversed = ReverseGrad.apply(sampled)
predictor_reversed = self_predictor(sampled_reversed)
predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
predictor_reversed[:-1,:,:]), dim=0)
self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
reverse_grad=True)
normalized_prob = (prob / (T * N)).to('cpu').item()
normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
loss = -(prob - self_prob)
normalized_loss = loss / (T * N)
if i % 200 == 0:
print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
normalized_loss.backward()
optim.step()
optim.zero_grad()
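For orientation, the quantity this loop maximizes (the negative of normalized_loss) can be read in the mutual-information terms hinted at in the docstrings above; the reading below is an interpretation with toy numbers, not output from the test:

# Per frame, schematically:
#   -normalized_loss  ≈  log p(sampled | feats)  -  log p(sampled | its own shifted past)
log_p_given_feats = -2.0   # toy value: well predicted from the input features
log_p_given_past = -4.5    # toy value: poorly predicted from the shifted codes alone
mi_like_objective = log_p_given_feats - log_p_given_past   # 2.5 (nats per frame)
# Driving this up pushes the bottleneck codes to carry information about the
# input that is not already predictable from their own history.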
if __name__ == '__main__':
_test_bidirectional_conformer()
_test_discrete_bottleneck()