Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 20:12:24 +00:00)

Commit 2bad68a8ed (parent: 39b6879d72)
"Trying to figure out why it's not converging.."
@@ -291,8 +291,6 @@ class BidirectionalConformer(nn.Module):
 
         if num_reverse_decoder_layers > 0:
 
-            self.reverse_decoder_pos = PositionalEncoding(d_model, dropout)
-
             decoder_layer = TransformerDecoderLayer(
                 d_model=d_model,
                 nhead=nhead,
@@ -313,8 +311,10 @@ class BidirectionalConformer(nn.Module):
         if num_self_predictor_layers > 0:
             encoder_layer = SimpleCausalEncoderLayer(d_model,
                                                      dropout=dropout)
+            final_linear = nn.Linear(d_model, d_model)
             self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
-                                                           for _ in range(num_self_predictor_layers)])
+                                                           for _ in range(num_self_predictor_layers)],
+                                                         final_linear)
 
 
         self.sample_and_predict = SampleAndPredict(
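
The added lines put a plain nn.Linear projection at the end of the self-predictor stack by passing it to nn.Sequential after the unpacked list of layer copies. A minimal sketch of that pattern, using a stock nn.TransformerEncoderLayer as a stand-in for SimpleCausalEncoderLayer (names here are illustrative):

    import copy
    import torch.nn as nn

    d_model, num_layers = 256, 3
    layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4)  # stand-in layer
    final_linear = nn.Linear(d_model, d_model)

    # nn.Sequential accepts any number of modules: the deep copies keep the
    # layers from sharing parameters, and final_linear is applied last.
    self_predictor_encoder = nn.Sequential(
        *[copy.deepcopy(layer) for _ in range(num_layers)],
        final_linear,
    )
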
@@ -371,8 +371,8 @@ class BidirectionalConformer(nn.Module):
         (T, N, E) = memory.shape
         device = memory.device
         zeros = torch.zeros(1, N, E).to(memory.device)
-        negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
         positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
+        negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
 
         return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
 
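
Both *_shifted tensors are built the same way: a single all-zero frame is prepended and the final frame dropped, so position t of the shifted tensor carries the embedding from position t-1 (a teacher-forcing style shift along the time axis). The change above only reorders the two statements; the operation itself is:

    import torch

    T, N, E = 5, 2, 4
    embed = torch.randn(T, N, E)          # (time, batch, embedding)

    zeros = torch.zeros(1, N, E)
    embed_shifted = torch.cat((zeros, embed[:-1, :, :]), dim=0)

    # embed_shifted[0] is all zeros; embed_shifted[t] == embed[t - 1] for t >= 1.
    assert torch.equal(embed_shifted[1:], embed[:-1])
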
@@ -509,7 +509,6 @@ class BidirectionalConformer(nn.Module):
         A scalar tensor, the **sum** of the log-prob loss over utterances
         in the batch without any normalization.
         """
-
         # no mask is needed for self_predictor_encoder; its CNN
         # layer uses left-padding only, making it causal, so the mask
         # is redundant (it wouldn't affect any of the
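
The comment's reasoning: a convolution padded only on the left is causal, because output frame t can only see input frames at times <= t, so an additional attention-style mask would change nothing. A rough sketch of left-only padding (illustrative, not the layer defined in this file):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LeftPaddedConv1d(nn.Module):
        """Causal 1-D convolution: pad kernel_size - 1 zeros on the left only."""
        def __init__(self, channels: int, kernel_size: int = 3):
            super().__init__()
            self.pad = kernel_size - 1
            self.conv = nn.Conv1d(channels, channels, kernel_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (N, C, T); output frame t depends only on x[..., :t + 1].
            return self.conv(F.pad(x, (self.pad, 0)))

    x = torch.randn(2, 8, 10)
    assert LeftPaddedConv1d(8)(x).shape == x.shape
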
@@ -630,9 +629,7 @@ class SimpleCausalEncoderLayer(nn.Module):
         self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
         self.ff_scale = 0.5
         self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        ) # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -743,7 +740,7 @@ class SampleAndPredict(nn.Module):
         self.classes_per_group = tot_classes // num_groups
 
         # prob_boost relates to the min_prob_ratio setting. It's not configurable for now.
-        self.prob_boost = 1.0e-05
+        self.prob_boost = 1.0e-03
 
         # class_probs is a rolling mean of the output of the sampling operation.
         # When any element of it gets below self.min_prob_ratio / self.classes_per_group,
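
The surrounding comments describe the mechanism this constant feeds: class_probs keeps a rolling mean of how often each class is sampled, and whenever an entry drops below min_prob_ratio / classes_per_group, that class's offset (added to the logits before the softmax) is bumped up by prob_boost so it starts being sampled again. Raising prob_boost from 1e-05 to 1e-03 only makes that correction act faster. A rough sketch of the idea as the comments describe it; the names and the decay constant below are illustrative, not taken from this file:

    import torch

    classes_per_group = 8
    min_prob_ratio = 0.1
    prob_boost = 1.0e-03     # value this commit switches to

    class_offsets = torch.zeros(classes_per_group)
    class_probs = torch.full((classes_per_group,), 1.0 / classes_per_group)

    def update_offsets(batch_freqs: torch.Tensor, decay: float = 0.95) -> None:
        """batch_freqs: empirical class frequencies from the latest sampling step."""
        global class_probs
        class_probs = decay * class_probs + (1.0 - decay) * batch_freqs
        # Boost the logit offset of any class that has become too rare.
        too_rare = class_probs < min_prob_ratio / classes_per_group
        class_offsets[too_rare] += prob_boost
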
@@ -819,8 +816,7 @@ class SampleAndPredict(nn.Module):
         something whose gradient is already (going to be) reversed:
         specifically, the self-prediction network.
         """
-        x = self.linear1(x * 5) # multiplying 5 gives lower entropy, makes it
-                                # begin training faster..
+        x = self.linear1(x)
 
         if self.min_prob_ratio > 0.0:
             x = x + self.class_offsets
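
The removed line scaled the input to linear1 by 5, which (bias aside) scales the resulting logits and acts like an inverse softmax temperature: larger logits give a sharper, lower-entropy distribution, which is what the deleted comment meant by "begin training faster". A small illustration of the effect:

    import torch

    logits = torch.tensor([1.0, 0.5, -0.5])

    def entropy(p: torch.Tensor) -> float:
        return float(-(p * p.log()).sum())

    p1 = logits.softmax(dim=0)          # unscaled
    p5 = (5 * logits).softmax(dim=0)    # scaled up, roughly what the removed line did

    # Scaling the logits sharpens the softmax, so its entropy drops.
    assert entropy(p5) < entropy(p1)
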
@@ -833,7 +829,7 @@ class SampleAndPredict(nn.Module):
         # 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
 
-        if random.random() < 0.001:
+        if random.random() < 0.01:
             # Some info that's useful for debug.
             softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
             logsoftmax_temp = (softmax_temp + 1.0e-20).log()
@@ -949,7 +945,7 @@ class SampleAndPredict(nn.Module):
 
         if reverse_grad:
             tot_prob = reverse_gradient(tot_prob)
         return tot_prob
 
 
 class ConformerEncoderLayer(nn.Module):

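reverse_gradient passes activations through unchanged in the forward direction but flips the sign of the gradient on the way back (the gradient-reversal trick used for adversarial-style objectives). A common way such a helper is written, sketched here on the assumption that the one in this file behaves the same:

    import torch

    class _ReverseGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.view_as(x)        # identity in the forward pass

        @staticmethod
        def backward(ctx, grad_output):
            return -grad_output        # negate the gradient in the backward pass

    def reverse_gradient_sketch(x: torch.Tensor) -> torch.Tensor:
        return _ReverseGrad.apply(x)

    x = torch.ones(3, requires_grad=True)
    reverse_gradient_sketch(x).sum().backward()
    assert torch.equal(x.grad, -torch.ones(3))
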
@@ -464,7 +464,7 @@ def compute_loss(
             # Will eventually remove this block..
             num_frames = supervision_segments[:, 2].sum().item()
             print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
-                  f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}"
+                  f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
                   f"reverse_att_loss = {reverse_att_loss/num_frames}")
         else:
             reverse_att_loss = torch.tensor([0.0]).to(device)
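
The two extra characters matter because adjacent f-string literals are concatenated implicitly: without the trailing ", " the two printed fields run together in the log line. For example:

    a, b = 1.23, 4.56
    print(f"x = {a}"        # no separator -> prints "x = 1.23y = 4.56"
          f"y = {b}")
    print(f"x = {a}, "      # with separator -> prints "x = 1.23, y = 4.56"
          f"y = {b}")
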
@@ -578,7 +578,7 @@ def train_one_epoch(
             graph_compiler=graph_compiler,
             is_training=True,
         )
-        tot_loss = (tot_loss * (1 + 1 / params.reset_interval)) + loss_info # summary stats.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # summary stats.
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
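
The sign fix changes the behaviour of the running statistics: with (1 - 1/reset_interval) the total is an exponentially decaying sum that settles at roughly reset_interval times a typical per-batch value, while (1 + 1/reset_interval) makes it grow without bound. A quick check of the scalar recurrence (tot_loss itself is a stats object, but the weighting is the same):

    reset_interval = 200
    per_batch = 1.0

    tot = 0.0
    for _ in range(10_000):
        tot = tot * (1 - 1 / reset_interval) + per_batch
    print(tot)       # ~200: geometric series converging to reset_interval * per_batch

    tot_bad = 0.0
    for _ in range(10_000):
        tot_bad = tot_bad * (1 + 1 / reset_interval) + per_batch
    print(tot_bad)   # astronomically large: the (1 + 1/N) factor compounds every step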