mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Some progress in testing..
This commit is contained in:
parent
a20d490332
commit
38081bc3e3
@ -320,7 +320,8 @@ class BidirectionalConformer(nn.Module):
|
||||
if num_self_predictor_layers > 0:
|
||||
encoder_layer = SimpleCausalEncoderLayer(d_model,
|
||||
dropout=dropout)
|
||||
self.self_predictor_encoder = encoder_layer
|
||||
self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
|
||||
for _ in range(num_self_predictor_layers)])
|
||||
|
||||
|
||||
self.discrete_bottleneck = DiscreteBottleneck(
|
||||
@ -478,11 +479,11 @@ class BidirectionalConformer(nn.Module):
|
||||
|
||||
def self_prediction_forward(
|
||||
self,
|
||||
memory_shifted: torch.Tensor,
|
||||
bn_memory_shifted: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
sampled: torch.Tensor,
|
||||
softmax: Optional[torch.Tensor],
|
||||
reverse_gradient: bool = True) -> Tensor:
|
||||
reverse_grad: bool = True) -> Tensor:
|
||||
"""
|
||||
Returns the total log-prob of the the labels sampled in the discrete
|
||||
bottleneck layer, as predicted using a relatively simple model that
|
||||
@ -490,11 +491,11 @@ class BidirectionalConformer(nn.Module):
|
||||
[Appears on the denominator of an expression for mutual information].
|
||||
|
||||
Args:
|
||||
memory_shifted:
|
||||
It's the output of forward(), with shape [T, N, E], shifted
|
||||
by one so that shifted_memory[t] == memory[t-1], as in:
|
||||
(T, N, E) = memory.shape
|
||||
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
|
||||
bn_memory_shifted:
|
||||
It's the bn_memory output of forward(), with shape [T, N, E], shifted
|
||||
by one so that bn_shifted_memory[t] == bn_memory[t-1], as in:
|
||||
(T, N, E) = bn_memory.shape
|
||||
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
|
||||
memory_key_padding_mask:
|
||||
The padding mask from the encoder, of shape [N, T], boolean, True
|
||||
for masked locations.
|
||||
@ -503,41 +504,41 @@ class BidirectionalConformer(nn.Module):
|
||||
as given to the constructor. This will be needed for the 'reverse'
|
||||
model.
|
||||
softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
|
||||
reverse_gradient: will likely be true. If true, the gradient is reversed twice
|
||||
reverse_grad: will likely be true. If true, the gradient is reversed twice
|
||||
in this computation, so that we train predictors with the correct
|
||||
gradient, i.e. to predict, not anti-predict (since the return value
|
||||
of this function will appear with positive, not negative, sign in the
|
||||
loss function, so will be minimized).
|
||||
The gradient w.r.t. the non-self inputs to this function, though (i.e.
|
||||
memory_shifted, sampled, softmax) will not be reversed, though.
|
||||
bn_memory_shifted, sampled, softmax) will not be reversed, though.
|
||||
Returns:
|
||||
A scalar tensor, the **sum** of label smoothing loss over utterances
|
||||
in the batch without any normalization.
|
||||
"""
|
||||
|
||||
if reverse_gradient:
|
||||
# Reversing gradient for memory_shifted puts the gradient back into
|
||||
if reverse_grad:
|
||||
# Reversing gradient for bn_memory_shifted puts the gradient back into
|
||||
# the correct sign; we reversed it in
|
||||
# self.discrete_bottleneck.compute_prob(), in order that
|
||||
# self.self.predictor_encoder can learn to predict.
|
||||
# (You have to read the code in reverse, to reason about
|
||||
# what happens to the gradients).
|
||||
memory_shifted = ReverseGrad.apply(memory_shifted)
|
||||
bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)
|
||||
|
||||
# no mask is needed for self_predictor_encoder; its CNN
|
||||
# layer uses left-padding only, making it causal.
|
||||
predictor = self.self_predictor_encoder(memory_shifted)
|
||||
predictor = self.self_predictor_encoder(bn_memory_shifted)
|
||||
|
||||
prob = self.discrete_bottleneck.compute_prob(predictor,
|
||||
sampled, softmax,
|
||||
memory_key_padding_mask,
|
||||
reverse_gradient=True)
|
||||
reverse_grad=True)
|
||||
return prob
|
||||
|
||||
|
||||
def reverse_decoder_forward(
|
||||
self,
|
||||
memory_shifted: torch.Tensor,
|
||||
bn_memory_shifted: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
sampled: torch.Tensor,
|
||||
softmax: Optional[torch.Tensor],
|
||||
@ -552,11 +553,11 @@ class BidirectionalConformer(nn.Module):
|
||||
supervision word-sequence.
|
||||
|
||||
Args:
|
||||
memory_shifted:
|
||||
It's the output of forward(), with shape [T, N, E], shifted
|
||||
by one so that shifted_memory[t] == memory[t-1], as in:
|
||||
bn_memory_shifted:
|
||||
It's the bn_memory output of forward(), with shape [T, N, E], shifted
|
||||
by one so that shifted_memory[t] == bn_memory[t-1], as in:
|
||||
(T, N, E) = memory.shape
|
||||
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
|
||||
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
|
||||
memory_key_padding_mask:
|
||||
The padding mask from the encoder.
|
||||
sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
|
||||
@ -582,7 +583,7 @@ class BidirectionalConformer(nn.Module):
|
||||
token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]
|
||||
|
||||
tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
|
||||
padding_value=padding_id).to(memory_shifted.device)
|
||||
padding_value=padding_id).to(bn_memory_shifted.device)
|
||||
|
||||
print("tokens_padded = ", tokens_padded)
|
||||
tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
|
||||
@ -600,12 +601,12 @@ class BidirectionalConformer(nn.Module):
|
||||
src_key_padding_mask=tokens_key_padding_mask)
|
||||
# token_memory is of shape (S, N, C), if S is length of token sequence.
|
||||
|
||||
T = memory_shifted.shape[0]
|
||||
T = bn_memory_shifted.shape[0]
|
||||
# the targets, here, are the hidden discrete symbols we are predicting
|
||||
tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
|
||||
tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)
|
||||
|
||||
hidden_predictor = self.reverse_decoder(
|
||||
tgt=memory_shifted,
|
||||
tgt=bn_memory_shifted,
|
||||
memory=token_memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_key_padding_mask=tokens_key_padding_mask)
|
||||
@ -780,7 +781,7 @@ class DiscreteBottleneck(nn.Module):
|
||||
torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
|
||||
|
||||
|
||||
def forward(self, x: Tensor, need_softmax: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Forward computation.
|
||||
Args:
|
||||
@ -813,7 +814,7 @@ class DiscreteBottleneck(nn.Module):
|
||||
|
||||
# This is a little wasteful since we already compute the softmax
|
||||
# inside 'flow_sample'.
|
||||
softmax = x.softmax().reshape(S, N, tot_classes) if need_softmax else None
|
||||
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
|
||||
|
||||
x = torch_flow_sampling.flow_sample(x,
|
||||
interp_prob=self.interp_prob,
|
||||
@ -829,15 +830,15 @@ class DiscreteBottleneck(nn.Module):
|
||||
self.class_probs = (self.class_probs * self.class_probs_decay +
|
||||
mean_class_probs * (1.0 - self.class_probs_decay))
|
||||
prob_floor = self.min_prob_ratio / self.classes_per_group
|
||||
self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
|
||||
self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
|
||||
|
||||
embedding = self.norm_out(self.linear2(x))
|
||||
return (embedding, sampled, softmax)
|
||||
|
||||
|
||||
def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
|
||||
padding_mask: Optional[Tensor],
|
||||
reverse_gradient: bool = False) -> Tensor:
|
||||
padding_mask: Optional[Tensor] = None,
|
||||
reverse_grad: bool = False) -> Tensor:
|
||||
"""
|
||||
Compute the total probability of the sampled probabilities, given
|
||||
some kind of predictor x (which we assume should not have access
|
||||
@ -846,7 +847,8 @@ class DiscreteBottleneck(nn.Module):
|
||||
|
||||
x: The predictor tensor, of shape (S, N, E) where S is the
|
||||
sequence length, N is the batch size and E is the embedding dim
|
||||
(`dim` arg to __init__())
|
||||
(`dim` arg to __init__()). This is projected from `sampled`
|
||||
with a learnable matrix.
|
||||
sampled: A tensor of shape (S, N, C) where C is the `tot_classes`
|
||||
to the constructor, containing the sampled probabilities.
|
||||
softmax: A tensor of shape (S, N, C), this is the "smooth" version
|
||||
@ -857,7 +859,7 @@ class DiscreteBottleneck(nn.Module):
|
||||
(batch_size, sequence_length), with True in masked positions
|
||||
that are to be ignored in the sum of probabilities.
|
||||
|
||||
reverse_gradient: If true, negate the gradient that is passed back
|
||||
reverse_grad: If true, negate the gradient that is passed back
|
||||
to 'x' and to the modules self.pred_linear and pred_cross.
|
||||
This will be useful in computing a loss function that has
|
||||
a likelihood term with negative sign (i.e. the self-prediction).
|
||||
@ -867,7 +869,7 @@ class DiscreteBottleneck(nn.Module):
|
||||
|
||||
Returns a scalar Tensor represnting the total probability.
|
||||
"""
|
||||
if reverse_gradient:
|
||||
if reverse_grad:
|
||||
sampled = ReverseGrad.apply(sampled)
|
||||
if softmax is None:
|
||||
softmax = sampled
|
||||
@ -906,7 +908,7 @@ class DiscreteBottleneck(nn.Module):
|
||||
else:
|
||||
tot_prob = (logprobs * softmax).sum()
|
||||
|
||||
if reverse_gradient:
|
||||
if reverse_grad:
|
||||
tot_prob = ReverseGrad.apply(tot_prob)
|
||||
return tot_prob
|
||||
|
||||
@ -1789,18 +1791,18 @@ def _test_bidirectional_conformer():
|
||||
eos_id=2)
|
||||
print("decoder logprob = ", decoder_logprob)
|
||||
|
||||
(T, N, E) = memory.shape
|
||||
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
|
||||
(T, N, E) = bn_memory.shape
|
||||
bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
|
||||
|
||||
reverse_decoder_logprob = m.reverse_decoder_forward(
|
||||
memory_shifted, key_padding_mask,
|
||||
bn_memory_shifted, key_padding_mask,
|
||||
sampled, softmax, tokens,
|
||||
sos_id=1, eos_id=2, padding_id=0)
|
||||
|
||||
print("reverse decoder logprob = ", reverse_decoder_logprob)
|
||||
|
||||
self_prediction_logprob = m.self_prediction_forward(
|
||||
memory_shifted, key_padding_mask,
|
||||
bn_memory_shifted, key_padding_mask,
|
||||
sampled, softmax)
|
||||
|
||||
print("self prediction logprob = ", self_prediction_logprob)
|
||||
@ -1809,5 +1811,67 @@ def _test_bidirectional_conformer():
|
||||
loss.backward()
|
||||
|
||||
|
||||
def _test_discrete_bottleneck():
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
dim = 128
|
||||
tot_classes = 256
|
||||
num_groups = 4
|
||||
interp_prob = 0.8
|
||||
straight_through_scale = 0.0 # will change
|
||||
need_softmax = True
|
||||
|
||||
b = DiscreteBottleneck(dim, tot_classes, num_groups,
|
||||
interp_prob, straight_through_scale).to(device)
|
||||
|
||||
self_predictor = nn.Linear(tot_classes, dim).to(device)
|
||||
|
||||
|
||||
optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
|
||||
lr=1.0e-03, momentum=0.99)
|
||||
|
||||
|
||||
for epoch in range(10):
|
||||
state_dict = dict()
|
||||
state_dict['b'] = b.state_dict()
|
||||
state_dict['s'] = self_predictor.state_dict()
|
||||
torch.save(state_dict, f'epoch-{epoch}.pt')
|
||||
for i in range(1000):
|
||||
# TODO: also test padding_mask
|
||||
T = 300
|
||||
N = 10
|
||||
|
||||
feats = torch.randn(T, N, dim, device=device)
|
||||
bn_memory, sampled, softmax = b(feats)
|
||||
|
||||
predictor = feats # Could also use `bn_memory`, perhaps. But using
|
||||
# predictor because it contains everything, will give
|
||||
# us a bound on max information..
|
||||
prob = b.compute_prob(feats, sampled, softmax)
|
||||
|
||||
sampled_reversed = ReverseGrad.apply(sampled)
|
||||
predictor_reversed = self_predictor(sampled_reversed)
|
||||
predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
|
||||
predictor_reversed[:-1,:,:]), dim=0)
|
||||
|
||||
self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
|
||||
reverse_grad=True)
|
||||
|
||||
normalized_prob = (prob / (T * N)).to('cpu').item()
|
||||
normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
|
||||
|
||||
loss = -(prob - self_prob)
|
||||
|
||||
normalized_loss = loss / (T * N)
|
||||
|
||||
if i % 200 == 0:
|
||||
print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
|
||||
|
||||
normalized_loss.backward()
|
||||
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test_bidirectional_conformer()
|
||||
_test_discrete_bottleneck()
|
||||
|
Loading…
x
Reference in New Issue
Block a user