mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Trying to figure out why it's not converging..
This commit is contained in:
parent
39b6879d72
commit
2bad68a8ed
@ -291,8 +291,6 @@ class BidirectionalConformer(nn.Module):
|
||||
|
||||
if num_reverse_decoder_layers > 0:
|
||||
|
||||
self.reverse_decoder_pos = PositionalEncoding(d_model, dropout)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
@ -313,8 +311,10 @@ class BidirectionalConformer(nn.Module):
|
||||
if num_self_predictor_layers > 0:
|
||||
encoder_layer = SimpleCausalEncoderLayer(d_model,
|
||||
dropout=dropout)
|
||||
final_linear = nn.Linear(d_model, d_model)
|
||||
self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
|
||||
for _ in range(num_self_predictor_layers)])
|
||||
for _ in range(num_self_predictor_layers)],
|
||||
final_linear)
|
||||
|
||||
|
||||
self.sample_and_predict = SampleAndPredict(
|
||||
@ -371,8 +371,8 @@ class BidirectionalConformer(nn.Module):
|
||||
(T, N, E) = memory.shape
|
||||
device = memory.device
|
||||
zeros = torch.zeros(1, N, E).to(memory.device)
|
||||
negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
|
||||
positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
|
||||
negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
|
||||
|
||||
return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
|
||||
|
||||
@ -509,7 +509,6 @@ class BidirectionalConformer(nn.Module):
|
||||
A scalar tensor, the **sum** of the log-prob loss over utterances
|
||||
in the batch without any normalization.
|
||||
"""
|
||||
|
||||
# no mask is needed for self_predictor_encoder; its CNN
|
||||
# layer uses left-padding only, making it causal, so the mask
|
||||
# is redundant (it wouldn't affect any of the
|
||||
@ -630,9 +629,7 @@ class SimpleCausalEncoderLayer(nn.Module):
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
self.ff_scale = 0.5
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the final output of the block
|
||||
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@ -743,7 +740,7 @@ class SampleAndPredict(nn.Module):
|
||||
self.classes_per_group = tot_classes // num_groups
|
||||
|
||||
# prob_boost relates to the min_prob_ratio setting. It's not configurable for now.
|
||||
self.prob_boost = 1.0e-05
|
||||
self.prob_boost = 1.0e-03
|
||||
|
||||
# class_probs is a rolling mean of the output of the sampling operation.
|
||||
# When any element of it gets below self.min_prob_ratio / self.classes_per_group,
|
||||
@ -819,8 +816,7 @@ class SampleAndPredict(nn.Module):
|
||||
something whose gradient is already (going to be) reversed:
|
||||
specifically, the self-prediction network.
|
||||
"""
|
||||
x = self.linear1(x * 5) # multiplying 5 gives lower entropy, makes it
|
||||
# begin training faster..
|
||||
x = self.linear1(x)
|
||||
|
||||
if self.min_prob_ratio > 0.0:
|
||||
x = x + self.class_offsets
|
||||
@ -833,7 +829,7 @@ class SampleAndPredict(nn.Module):
|
||||
# 'flow_sample'.
|
||||
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
|
||||
|
||||
if random.random() < 0.001:
|
||||
if random.random() < 0.01:
|
||||
# Some info that's useful for debug.
|
||||
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
|
||||
logsoftmax_temp = (softmax_temp + 1.0e-20).log()
|
||||
@ -949,7 +945,7 @@ class SampleAndPredict(nn.Module):
|
||||
|
||||
if reverse_grad:
|
||||
tot_prob = reverse_gradient(tot_prob)
|
||||
return tot_prob
|
||||
return tot_prob
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
|
@ -464,7 +464,7 @@ def compute_loss(
|
||||
# Will eventually remove this block..
|
||||
num_frames = supervision_segments[:, 2].sum().item()
|
||||
print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
|
||||
f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}"
|
||||
f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
|
||||
f"reverse_att_loss = {reverse_att_loss/num_frames}")
|
||||
else:
|
||||
reverse_att_loss = torch.tensor([0.0]).to(device)
|
||||
@ -578,7 +578,7 @@ def train_one_epoch(
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
)
|
||||
tot_loss = (tot_loss * (1 + 1 / params.reset_interval)) + loss_info # summary stats.
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # summary stats.
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
|
Loading…
x
Reference in New Issue
Block a user