Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 20:12:24 +00:00)

Commit 38081bc3e3 (parent a20d490332): Some progress in testing..
@@ -320,7 +320,8 @@ class BidirectionalConformer(nn.Module):
         if num_self_predictor_layers > 0:
             encoder_layer = SimpleCausalEncoderLayer(d_model,
                                                      dropout=dropout)
-            self.self_predictor_encoder = encoder_layer
+            self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
+                                                          for _ in range(num_self_predictor_layers)])


         self.discrete_bottleneck = DiscreteBottleneck(
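The change above stacks num_self_predictor_layers independent copies of the layer instead of reusing a single instance; the copy.deepcopy matters because each copy then gets its own parameters. A minimal sketch of the same pattern (the helper name below is illustrative only):

    import copy
    import torch.nn as nn

    def stack_layers(layer: nn.Module, num_layers: int) -> nn.Sequential:
        # deepcopy gives every layer its own parameters; reusing the same
        # object num_layers times would share, and repeatedly apply, one
        # set of weights.
        return nn.Sequential(*[copy.deepcopy(layer) for _ in range(num_layers)])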
@@ -478,11 +479,11 @@ class BidirectionalConformer(nn.Module):

     def self_prediction_forward(
             self,
-            memory_shifted: torch.Tensor,
+            bn_memory_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
-            reverse_gradient: bool = True) -> Tensor:
+            reverse_grad: bool = True) -> Tensor:
         """
         Returns the total log-prob of the the labels sampled in the discrete
         bottleneck layer, as predicted using a relatively simple model that
@@ -490,11 +491,11 @@ class BidirectionalConformer(nn.Module):
         [Appears on the denominator of an expression for mutual information].

         Args:
-          memory_shifted:
-            It's the output of forward(), with shape [T, N, E], shifted
-            by one so that shifted_memory[t] == memory[t-1], as in:
-              (T, N, E) = memory.shape
-              memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+          bn_memory_shifted:
+            It's the bn_memory output of forward(), with shape [T, N, E], shifted
+            by one so that bn_shifted_memory[t] == bn_memory[t-1], as in:
+              (T, N, E) = bn_memory.shape
+              bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
           memory_key_padding_mask:
             The padding mask from the encoder, of shape [N, T], boolean, True
             for masked locations.
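The "denominator" remark in this docstring can be read as follows (a sketch of the intended objective, inferred from how prob and self_prob are combined in the test added at the bottom of this commit, loss = -(prob - self_prob)): the mutual information between the sampled codes s and the acoustic input x is estimated roughly as

    I(s; x) ≈ E[ log p(s | x) ] - E[ log p(s | earlier s) ]
            = E[ log ( p(s | x) / p(s | earlier s) ) ]

so the self-prediction log-prob plays the role of the denominator inside the log ratio: codes are rewarded for being predictable from the acoustics but not from their own past.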
@@ -503,41 +504,41 @@ class BidirectionalConformer(nn.Module):
             as given to the constructor. This will be needed for the 'reverse'
             model.
           softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
-          reverse_gradient: will likely be true. If true, the gradient is reversed twice
+          reverse_grad: will likely be true. If true, the gradient is reversed twice
             in this computation, so that we train predictors with the correct
             gradient, i.e. to predict, not anti-predict (since the return value
             of this function will appear with positive, not negative, sign in the
             loss function, so will be minimized).
             The gradient w.r.t. the non-self inputs to this function, though (i.e.
-            memory_shifted, sampled, softmax) will not be reversed, though.
+            bn_memory_shifted, sampled, softmax) will not be reversed, though.
         Returns:
            A scalar tensor, the **sum** of label smoothing loss over utterances
            in the batch without any normalization.
         """

-        if reverse_gradient:
-            # Reversing gradient for memory_shifted puts the gradient back into
+        if reverse_grad:
+            # Reversing gradient for bn_memory_shifted puts the gradient back into
             # the correct sign; we reversed it in
             # self.discrete_bottleneck.compute_prob(), in order that
             # self.self.predictor_encoder can learn to predict.
             # (You have to read the code in reverse, to reason about
             # what happens to the gradients).
-            memory_shifted = ReverseGrad.apply(memory_shifted)
+            bn_memory_shifted = ReverseGrad.apply(bn_memory_shifted)

         # no mask is needed for self_predictor_encoder; its CNN
         # layer uses left-padding only, making it causal.
-        predictor = self.self_predictor_encoder(memory_shifted)
+        predictor = self.self_predictor_encoder(bn_memory_shifted)

         prob = self.discrete_bottleneck.compute_prob(predictor,
                                                      sampled, softmax,
                                                      memory_key_padding_mask,
-                                                     reverse_gradient=True)
+                                                     reverse_grad=True)
         return prob


     def reverse_decoder_forward(
             self,
-            memory_shifted: torch.Tensor,
+            bn_memory_shifted: torch.Tensor,
             memory_key_padding_mask: torch.Tensor,
             sampled: torch.Tensor,
             softmax: Optional[torch.Tensor],
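The gradient reversal used above is, in essence, an identity op whose backward pass negates the incoming gradient; a minimal sketch of such a function (assuming ReverseGrad follows this standard pattern; the real implementation may differ in detail):

    import torch

    class ReverseGradSketch(torch.autograd.Function):
        # Forward pass: identity.
        @staticmethod
        def forward(ctx, x):
            return x

        # Backward pass: flip the sign, so modules upstream of this op are
        # trained to do the opposite of what the downstream loss rewards.
        @staticmethod
        def backward(ctx, grad_output):
            return -grad_output

    # usage: y = ReverseGradSketch.apply(x)

Reversing twice, once on the output inside compute_prob() and once on bn_memory_shifted here, lets the self-predictor learn to predict while the network that produced bn_memory_shifted is still pushed in the opposite direction, as the comments above describe.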
@@ -552,11 +553,11 @@ class BidirectionalConformer(nn.Module):
         supervision word-sequence.

         Args:
-          memory_shifted:
-            It's the output of forward(), with shape [T, N, E], shifted
-            by one so that shifted_memory[t] == memory[t-1], as in:
+          bn_memory_shifted:
+            It's the bn_memory output of forward(), with shape [T, N, E], shifted
+            by one so that shifted_memory[t] == bn_memory[t-1], as in:
               (T, N, E) = memory.shape
-              memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+              bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)
           memory_key_padding_mask:
             The padding mask from the encoder.
           sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
@@ -582,7 +583,7 @@ class BidirectionalConformer(nn.Module):
         token_ids_tensors = [ torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids ]

         tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
-                                     padding_value=padding_id).to(memory_shifted.device)
+                                     padding_value=padding_id).to(bn_memory_shifted.device)

         print("tokens_padded = ", tokens_padded)
         tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
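As a concrete illustration of what the padded token batch above looks like (the token ids below are made up, and decoder_padding_mask is assumed here to simply mark positions equal to the padding id):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    sos_id, eos_id, padding_id = 1, 2, 0
    token_ids = [[5, 7, 9], [4, 6]]          # two utterances, hypothetical ids
    tensors = [torch.tensor([sos_id] + utt + [eos_id]) for utt in token_ids]
    tokens_padded = pad_sequence(tensors, batch_first=True, padding_value=padding_id)
    # tensor([[1, 5, 7, 9, 2],
    #         [1, 4, 6, 2, 0]])
    tokens_key_padding_mask = tokens_padded == padding_id  # True at padded positions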
@@ -600,12 +601,12 @@ class BidirectionalConformer(nn.Module):
                                          src_key_padding_mask=tokens_key_padding_mask)
         # token_memory is of shape (S, N, C), if S is length of token sequence.

-        T = memory_shifted.shape[0]
+        T = bn_memory_shifted.shape[0]
         # the targets, here, are the hidden discrete symbols we are predicting
-        tgt_mask = generate_square_subsequent_mask(T, device=memory_shifted.device)
+        tgt_mask = generate_square_subsequent_mask(T, device=bn_memory_shifted.device)

         hidden_predictor = self.reverse_decoder(
-            tgt=memory_shifted,
+            tgt=bn_memory_shifted,
             memory=token_memory,
             tgt_mask=tgt_mask,
             memory_key_padding_mask=tokens_key_padding_mask)
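generate_square_subsequent_mask builds the usual causal attention mask, so position t of the reverse decoder's target can only attend to positions <= t of the shifted bottleneck memory. A sketch of the standard construction (assuming the helper follows the common additive-mask convention):

    import torch

    def square_subsequent_mask(T: int, device=None) -> torch.Tensor:
        # -inf above the diagonal blocks attention to future positions;
        # 0 elsewhere leaves past and present positions visible.
        mask = torch.full((T, T), float('-inf'), device=device)
        return torch.triu(mask, diagonal=1)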
@@ -780,7 +781,7 @@ class DiscreteBottleneck(nn.Module):
         torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))


-    def forward(self, x: Tensor, need_softmax: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
         """
         Forward computation.
         Args:
@@ -813,7 +814,7 @@ class DiscreteBottleneck(nn.Module):

         # This is a little wasteful since we already compute the softmax
         # inside 'flow_sample'.
-        softmax = x.softmax().reshape(S, N, tot_classes) if need_softmax else None
+        softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None

         x = torch_flow_sampling.flow_sample(x,
                                             interp_prob=self.interp_prob,
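Tensor.softmax requires an explicit dim, so the old call would not run; dim=3 makes sense if x at this point has shape (S, N, num_groups, classes_per_group) with tot_classes = num_groups * classes_per_group (an inference from the reshape). A toy shape check:

    import torch

    S, N, num_groups, classes_per_group = 300, 10, 4, 64
    tot_classes = num_groups * classes_per_group
    x = torch.randn(S, N, num_groups, classes_per_group)
    # Normalize within each group (last dim), then flatten the group dims:
    softmax = x.softmax(dim=3).reshape(S, N, tot_classes)
    assert torch.allclose(softmax.reshape(S, N, num_groups, -1).sum(dim=3),
                          torch.ones(S, N, num_groups), atol=1e-5)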
@@ -829,15 +830,15 @@ class DiscreteBottleneck(nn.Module):
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
             prob_floor = self.min_prob_ratio / self.classes_per_group
-            self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
+            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost

         embedding = self.norm_out(self.linear2(x))
         return (embedding, sampled, softmax)


     def compute_prob(self, x: Tensor, sampled: Tensor, softmax: Optional[Tensor],
-                     padding_mask: Optional[Tensor],
-                     reverse_gradient: bool = False) -> Tensor:
+                     padding_mask: Optional[Tensor] = None,
+                     reverse_grad: bool = False) -> Tensor:
         """
         Compute the total probability of the sampled probabilities, given
         some kind of predictor x (which we assume should not have access
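The one-character change from > to < above means the offset boost now goes to classes whose running average probability has fallen below the floor, which matches the apparent intent of min_prob_ratio and prob_boost (keeping rarely-used classes alive). A toy illustration of the corrected update, with made-up values:

    import torch

    class_probs = torch.tensor([0.30, 0.001, 0.20, 0.002])   # running averages
    prob_floor, prob_boost = 0.01, 0.1
    class_offsets = torch.zeros(4)
    # Only the classes currently under the floor get boosted.
    class_offsets += (class_probs < prob_floor) * prob_boost
    print(class_offsets)   # tensor([0.0000, 0.1000, 0.0000, 0.1000])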
@@ -846,7 +847,8 @@ class DiscreteBottleneck(nn.Module):

           x: The predictor tensor, of shape (S, N, E) where S is the
              sequence length, N is the batch size and E is the embedding dim
-             (`dim` arg to __init__())
+             (`dim` arg to __init__()). This is projected from `sampled`
+             with a learnable matrix.
           sampled: A tensor of shape (S, N, C) where C is the `tot_classes`
              to the constructor, containing the sampled probabilities.
           softmax: A tensor of shape (S, N, C), this is the "smooth" version
@@ -857,7 +859,7 @@ class DiscreteBottleneck(nn.Module):
             (batch_size, sequence_length), with True in masked positions
             that are to be ignored in the sum of probabilities.

-          reverse_gradient: If true, negate the gradient that is passed back
+          reverse_grad: If true, negate the gradient that is passed back
             to 'x' and to the modules self.pred_linear and pred_cross.
             This will be useful in computing a loss function that has
             a likelihood term with negative sign (i.e. the self-prediction).
@@ -867,7 +869,7 @@ class DiscreteBottleneck(nn.Module):

         Returns a scalar Tensor represnting the total probability.
         """
-        if reverse_gradient:
+        if reverse_grad:
             sampled = ReverseGrad.apply(sampled)
         if softmax is None:
             softmax = sampled
@@ -906,7 +908,7 @@ class DiscreteBottleneck(nn.Module):
         else:
             tot_prob = (logprobs * softmax).sum()

-        if reverse_gradient:
+        if reverse_grad:
             tot_prob = ReverseGrad.apply(tot_prob)
         return tot_prob

@@ -1789,18 +1791,18 @@ def _test_bidirectional_conformer():
                                              eos_id=2)
     print("decoder logprob = ", decoder_logprob)

-    (T, N, E) = memory.shape
-    memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
+    (T, N, E) = bn_memory.shape
+    bn_memory_shifted = torch.cat((torch.zeros(1, N, E), bn_memory[:-1,:,:]), dim=0)

     reverse_decoder_logprob = m.reverse_decoder_forward(
-        memory_shifted, key_padding_mask,
+        bn_memory_shifted, key_padding_mask,
         sampled, softmax, tokens,
         sos_id=1, eos_id=2, padding_id=0)

     print("reverse decoder logprob = ", reverse_decoder_logprob)

     self_prediction_logprob = m.self_prediction_forward(
-        memory_shifted, key_padding_mask,
+        bn_memory_shifted, key_padding_mask,
         sampled, softmax)

     print("self prediction logprob = ", self_prediction_logprob)
@@ -1809,5 +1811,67 @@ def _test_bidirectional_conformer():
     loss.backward()


+def _test_discrete_bottleneck():
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    dim = 128
+    tot_classes = 256
+    num_groups = 4
+    interp_prob = 0.8
+    straight_through_scale = 0.0 # will change
+    need_softmax = True
+
+    b = DiscreteBottleneck(dim, tot_classes, num_groups,
+                           interp_prob, straight_through_scale).to(device)
+
+    self_predictor = nn.Linear(tot_classes, dim).to(device)
+
+
+    optim = torch.optim.SGD(params=(list(b.parameters()) + list(self_predictor.parameters())),
+                            lr=1.0e-03, momentum=0.99)
+
+
+    for epoch in range(10):
+        state_dict = dict()
+        state_dict['b'] = b.state_dict()
+        state_dict['s'] = self_predictor.state_dict()
+        torch.save(state_dict, f'epoch-{epoch}.pt')
+        for i in range(1000):
+            # TODO: also test padding_mask
+            T = 300
+            N = 10
+
+            feats = torch.randn(T, N, dim, device=device)
+            bn_memory, sampled, softmax = b(feats)
+
+            predictor = feats  # Could also use `bn_memory`, perhaps. But using
+                               # predictor because it contains everything, will give
+                               # us a bound on max information..
+            prob = b.compute_prob(feats, sampled, softmax)
+
+            sampled_reversed = ReverseGrad.apply(sampled)
+            predictor_reversed = self_predictor(sampled_reversed)
+            predictor_reversed_shifted = torch.cat((torch.zeros(1, N, dim).to(device),
+                                                    predictor_reversed[:-1,:,:]), dim=0)
+
+            self_prob = b.compute_prob(predictor_reversed_shifted, sampled, softmax,
+                                       reverse_grad=True)
+
+            normalized_prob = (prob / (T * N)).to('cpu').item()
+            normalized_self_prob = (self_prob / (T * N)).to('cpu').item()
+
+            loss = -(prob - self_prob)
+
+            normalized_loss = loss / (T * N)
+
+            if i % 200 == 0:
+                print(f"Epoch {epoch}, iteration {i}, normalized loss/frame is {-normalized_prob} - {-normalized_self_prob} = {normalized_loss.to('cpu').item()}")
+
+            normalized_loss.backward()
+
+            optim.step()
+            optim.zero_grad()
+
+
 if __name__ == '__main__':
     _test_bidirectional_conformer()
+    _test_discrete_bottleneck()