mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00
Add some more debug stuff: seems like things move around too fast for negative branch to track..
This commit is contained in:
parent
ed84795b47
commit
656de090bd
@ -311,10 +311,11 @@ class BidirectionalConformer(nn.Module):
|
||||
if num_self_predictor_layers > 0:
|
||||
encoder_layer = SimpleCausalEncoderLayer(d_model,
|
||||
dropout=dropout)
|
||||
final_linear = nn.Linear(d_model, d_model)
|
||||
final_linear = nn.Linear(d_model, d_model, bias=False)
|
||||
self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
|
||||
for _ in range(num_self_predictor_layers)],
|
||||
final_linear)
|
||||
final_linear,
|
||||
FastOffsetLayer(d_model))
|
||||
|
||||
|
||||
self.sample_and_predict = SampleAndPredict(
|
||||
@ -604,6 +605,28 @@ class BidirectionalConformer(nn.Module):
|
||||
return total_prob
|
||||
|
||||
|
||||
class FastOffsetLayer(nn.Module):
|
||||
"""
|
||||
A layer that rapidly learns an offset/bias on its output
|
||||
"""
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
bias_scale: float = 100.0):
|
||||
super(FastOffsetLayer, self).__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(dim))
|
||||
self.bias_scale = bias_scale
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
An offset is added, treating the last dim of x as the channel dim.
|
||||
"""
|
||||
if random.random() < 0.005:
|
||||
print("bias = ", self.bias)
|
||||
return x + self.bias * self.bias_scale
|
||||
|
||||
|
||||
|
||||
|
||||
class SimpleCausalEncoderLayer(nn.Module):
|
||||
"""
|
||||
This is a simple encoder layer that only sees left-context; it is
|
||||
@ -816,7 +839,7 @@ class SampleAndPredict(nn.Module):
|
||||
something whose gradient is already (going to be) reversed:
|
||||
specifically, the self-prediction network.
|
||||
"""
|
||||
x = self.linear1(x)
|
||||
x = self.linear1(x) * 3
|
||||
|
||||
if self.min_prob_ratio > 0.0:
|
||||
x = x + self.class_offsets
|
||||
@ -829,7 +852,7 @@ class SampleAndPredict(nn.Module):
|
||||
# 'flow_sample'.
|
||||
softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
|
||||
|
||||
if random.random() < 0.01:
|
||||
if random.random() < 0.05:
|
||||
# Some info that's useful for debug.
|
||||
softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
|
||||
logsoftmax_temp = (softmax_temp + 1.0e-20).log()
|
||||
@ -841,7 +864,8 @@ class SampleAndPredict(nn.Module):
|
||||
|
||||
print("SampleAndPredict: entropy = ",
|
||||
-negentropy.to('cpu').item(), ", averaged entropy = ",
|
||||
-global_negentropy.to('cpu').item())
|
||||
-global_negentropy.to('cpu').item(),
|
||||
", argmax = ", (global_softmax * global_log_softmax).argmax(dim=-1).to('cpu'))
|
||||
|
||||
|
||||
x = torch_flow_sampling.flow_sample(x,
|
||||
|
Loading…
x
Reference in New Issue
Block a user