Trying to figure out why it's not converging..

Daniel Povey 2021-09-20 13:18:46 +08:00
parent 39b6879d72
commit 2bad68a8ed
2 changed files with 11 additions and 15 deletions

@@ -291,8 +291,6 @@ class BidirectionalConformer(nn.Module):
if num_reverse_decoder_layers > 0:
self.reverse_decoder_pos = PositionalEncoding(d_model, dropout)
decoder_layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
@@ -313,8 +311,10 @@ class BidirectionalConformer(nn.Module):
if num_self_predictor_layers > 0:
encoder_layer = SimpleCausalEncoderLayer(d_model,
dropout=dropout)
+ final_linear = nn.Linear(d_model, d_model)
self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
- for _ in range(num_self_predictor_layers)])
+ for _ in range(num_self_predictor_layers)],
+ final_linear)
self.sample_and_predict = SampleAndPredict(
@@ -371,8 +371,8 @@ class BidirectionalConformer(nn.Module):
(T, N, E) = memory.shape
device = memory.device
zeros = torch.zeros(1, N, E).to(memory.device)
- negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
+ negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
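
A side note on the shift above: concatenating a frame of zeros in front and dropping the last frame moves every embedding one step later in time, so position t of the shifted tensor holds the embedding from position t-1 (and position 0 holds zeros). A minimal standalone illustration, with made-up shapes and names:

import torch

S, N, E = 4, 1, 3                                  # (time, batch, embedding_dim), like (T, N, E) above
embed = torch.arange(S).float().view(S, 1, 1).expand(S, N, E)
zeros = torch.zeros(1, N, E)
shifted = torch.cat((zeros, embed[:-1, :, :]), dim=0)
# embed[:, 0, 0]   -> tensor([0., 1., 2., 3.])
# shifted[:, 0, 0] -> tensor([0., 0., 1., 2.])     # frame t now sees the embedding of frame t-1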
@@ -509,7 +509,6 @@ class BidirectionalConformer(nn.Module):
A scalar tensor, the **sum** of the log-prob loss over utterances
in the batch without any normalization.
"""
# no mask is needed for self_predictor_encoder; its CNN
# layer uses left-padding only, making it causal, so the mask
# is redundant (it wouldn't affect any of the
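
For reference, a minimal sketch of the left-padding trick this comment relies on; it is illustrative only, not the actual SimpleCausalEncoderLayer code. Padding a 1-D convolution on the left by kernel_size - 1 (and not on the right) makes it causal, so output frame t depends only on input frames <= t and no attention-style mask is required.

import torch
import torch.nn as nn
import torch.nn.functional as F

kernel_size = 3
conv = nn.Conv1d(in_channels=8, out_channels=8, kernel_size=kernel_size)

x = torch.randn(1, 8, 10)                    # (batch, channels, time)
x_padded = F.pad(x, (kernel_size - 1, 0))    # pad the time axis on the left only
y = conv(x_padded)                           # y[..., t] depends only on x[..., :t + 1]
assert y.shape[-1] == x.shape[-1]            # output length matches the input length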
@@ -630,9 +629,7 @@ class SimpleCausalEncoderLayer(nn.Module):
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
- self.norm_final = nn.LayerNorm(
- d_model
- ) # for the final output of the block
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@@ -743,7 +740,7 @@ class SampleAndPredict(nn.Module):
self.classes_per_group = tot_classes // num_groups
# prob_boost relates to the min_prob_ratio setting. It's not configurable for now.
- self.prob_boost = 1.0e-05
+ self.prob_boost = 1.0e-03
# class_probs is a rolling mean of the output of the sampling operation.
# When any element of it gets below self.min_prob_ratio / self.classes_per_group,
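
A rough, purely illustrative sketch of the mechanism these comments describe; the exact update rule here is a guess (the real class keeps class_probs as a rolling mean and adds class_offsets to the logits, as the code further down shows). The idea is that any class whose rolling probability falls below min_prob_ratio / classes_per_group gets its offset boosted by prob_boost so it is sampled more often:

import torch

classes_per_group = 4
min_prob_ratio = 0.1
prob_boost = 1.0e-03

class_probs = torch.tensor([0.50, 0.40, 0.08, 0.02])    # rolling mean of sampled class frequencies
class_offsets = torch.zeros(classes_per_group)          # gets added to the logits before sampling

threshold = min_prob_ratio / classes_per_group          # 0.025 here
class_offsets += prob_boost * (class_probs < threshold).float()   # boost only the rare class (index 3)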
@@ -819,8 +816,7 @@ class SampleAndPredict(nn.Module):
something whose gradient is already (going to be) reversed:
specifically, the self-prediction network.
"""
- x = self.linear1(x * 5) # multiplying 5 gives lower entropy, makes it
- # begin training faster..
+ x = self.linear1(x)
if self.min_prob_ratio > 0.0:
x = x + self.class_offsets
@@ -833,7 +829,7 @@ class SampleAndPredict(nn.Module):
# 'flow_sample'.
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
- if random.random() < 0.001:
+ if random.random() < 0.01:
# Some info that's useful for debug.
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
logsoftmax_temp = (softmax_temp + 1.0e-20).log()
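
The debug block above builds the per-group softmax and a clipped log of it; a typical statistic to print from those two tensors is the average per-group entropy, which shows whether the class distributions are collapsing. The continuation below is a guess for illustration, not the repository's actual debug code:

import torch

S, N, num_groups, classes_per_group = 5, 2, 3, 4
softmax_temp = torch.softmax(torch.randn(S, N, num_groups, classes_per_group), dim=-1)
logsoftmax_temp = (softmax_temp + 1.0e-20).log()        # the small epsilon avoids log(0)

# Average entropy in nats over time, batch and groups; the maximum possible value
# is log(classes_per_group), so a much smaller number means low-entropy distributions.
entropy = -(softmax_temp * logsoftmax_temp).sum(dim=-1).mean()
print(f"Average per-group entropy = {entropy}")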
@@ -949,7 +945,7 @@ class SampleAndPredict(nn.Module):
if reverse_grad:
tot_prob = reverse_gradient(tot_prob)
- return tot_prob
+ return tot_prob
class ConformerEncoderLayer(nn.Module):
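
For context, the reverse_gradient call used above in SampleAndPredict follows the usual gradient-reversal-layer pattern: the identity in the forward pass, with the gradient negated in the backward pass, which lets one sub-network be trained adversarially against another. A generic sketch of that pattern, not necessarily identical to the helper defined in this file:

import torch

class ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x                      # identity on the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output           # flip the sign of the incoming gradient

def reverse_gradient(x):
    return ReverseGrad.apply(x)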

@@ -464,7 +464,7 @@ def compute_loss(
# Will eventually remove this block..
num_frames = supervision_segments[:, 2].sum().item()
print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}"
f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
f"reverse_att_loss = {reverse_att_loss/num_frames}")
else:
reverse_att_loss = torch.tensor([0.0]).to(device)
@@ -578,7 +578,7 @@ def train_one_epoch(
graph_compiler=graph_compiler,
is_training=True,
)
- tot_loss = (tot_loss * (1 + 1 / params.reset_interval)) + loss_info # summary stats.
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # summary stats.
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
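
A note on the corrected update to tot_loss above: multiplying the running stats by (1 - 1 / params.reset_interval) before adding the new batch keeps them as an exponentially decaying sum with an effective window of roughly reset_interval batches, whereas the old factor (1 + 1 / params.reset_interval) would have let them grow without bound. A tiny numeric illustration with an arbitrary reset_interval and a constant per-batch contribution:

reset_interval = 200
decay = 1 - 1 / reset_interval        # 0.995

tot = 0.0
for step in range(1000):
    loss_info = 1.0                   # pretend every batch contributes 1.0
    tot = tot * decay + loss_info

# tot approaches 1 / (1 - decay) == reset_interval == 200, i.e. roughly
# the sum over the most recent `reset_interval` batches.
print(tot)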