Add some more debug stuff: seems like things move around too fast for negative branch to track..

This commit is contained in:
Daniel Povey 2021-09-20 16:11:30 +08:00
parent ed84795b47
commit 656de090bd

View File

@ -311,10 +311,11 @@ class BidirectionalConformer(nn.Module):
if num_self_predictor_layers > 0:
encoder_layer = SimpleCausalEncoderLayer(d_model,
dropout=dropout)
final_linear = nn.Linear(d_model, d_model)
final_linear = nn.Linear(d_model, d_model, bias=False)
self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
for _ in range(num_self_predictor_layers)],
final_linear)
final_linear,
FastOffsetLayer(d_model))
self.sample_and_predict = SampleAndPredict(
@ -604,6 +605,28 @@ class BidirectionalConformer(nn.Module):
return total_prob
class FastOffsetLayer(nn.Module):
"""
A layer that rapidly learns an offset/bias on its output
"""
def __init__(self,
dim: int,
bias_scale: float = 100.0):
super(FastOffsetLayer, self).__init__()
self.bias = nn.Parameter(torch.zeros(dim))
self.bias_scale = bias_scale
def forward(self, x):
"""
An offset is added, treating the last dim of x as the channel dim.
"""
if random.random() < 0.005:
print("bias = ", self.bias)
return x + self.bias * self.bias_scale
class SimpleCausalEncoderLayer(nn.Module):
"""
This is a simple encoder layer that only sees left-context; it is
@ -816,7 +839,7 @@ class SampleAndPredict(nn.Module):
something whose gradient is already (going to be) reversed:
specifically, the self-prediction network.
"""
x = self.linear1(x)
x = self.linear1(x) * 3
if self.min_prob_ratio > 0.0:
x = x + self.class_offsets
@ -829,7 +852,7 @@ class SampleAndPredict(nn.Module):
# 'flow_sample'.
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
if random.random() < 0.01:
if random.random() < 0.05:
# Some info that's useful for debug.
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
logsoftmax_temp = (softmax_temp + 1.0e-20).log()
@ -841,7 +864,8 @@ class SampleAndPredict(nn.Module):
print("SampleAndPredict: entropy = ",
-negentropy.to('cpu').item(), ", averaged entropy = ",
-global_negentropy.to('cpu').item())
-global_negentropy.to('cpu').item(),
", argmax = ", (global_softmax * global_log_softmax).argmax(dim=-1).to('cpu'))
x = torch_flow_sampling.flow_sample(x,