diff --git a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
index bdb9fcb01..010f97155 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn_2d/conformer.py
@@ -311,10 +311,11 @@ class BidirectionalConformer(nn.Module):
 
         if num_self_predictor_layers > 0:
             encoder_layer = SimpleCausalEncoderLayer(d_model, dropout=dropout)
-            final_linear = nn.Linear(d_model, d_model)
+            final_linear = nn.Linear(d_model, d_model, bias=False)
             self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer) for _ in range(num_self_predictor_layers)],
-                                                        final_linear)
+                                                        final_linear,
+                                                        FastOffsetLayer(d_model))
 
         self.sample_and_predict = SampleAndPredict(
@@ -604,6 +605,28 @@ class BidirectionalConformer(nn.Module):
         return total_prob
 
 
+class FastOffsetLayer(nn.Module):
+    """
+    A layer that rapidly learns an offset/bias on its output.
+    """
+    def __init__(self,
+                 dim: int,
+                 bias_scale: float = 100.0):
+        super(FastOffsetLayer, self).__init__()
+        self.bias = nn.Parameter(torch.zeros(dim))
+        self.bias_scale = bias_scale
+
+    def forward(self, x):
+        """
+        An offset is added, treating the last dim of x as the channel dim.
+        """
+        if random.random() < 0.005:
+            print("bias = ", self.bias)
+        return x + self.bias * self.bias_scale
+
+
+
 class SimpleCausalEncoderLayer(nn.Module):
     """
     This is a simple encoder layer that only sees left-context; it is
@@ -816,7 +839,7 @@ class SampleAndPredict(nn.Module):
         something whose gradient is already (going to be) reversed:
         specifically, the self-prediction network.
         """
-        x = self.linear1(x)
+        x = self.linear1(x) * 3
 
         if self.min_prob_ratio > 0.0:
             x = x + self.class_offsets
@@ -829,7 +852,7 @@ class SampleAndPredict(nn.Module):
         # 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
 
-        if random.random() < 0.01:
+        if random.random() < 0.05:
            # Some info that's useful for debug.
            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
@@ -841,7 +864,8 @@ class SampleAndPredict(nn.Module):
            print("SampleAndPredict: entropy = ",
                  -negentropy.to('cpu').item(),
                  ", averaged entropy = ",
-                 -global_negentropy.to('cpu').item())
+                 -global_negentropy.to('cpu').item(),
+                 ", argmax = ", (global_softmax * global_log_softmax).argmax(dim=-1).to('cpu'))
 
         x = torch_flow_sampling.flow_sample(x,
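
A note on the `FastOffsetLayer` idea, since the patch carries only the code: keeping the stored parameter zero-initialized and multiplying it by `bias_scale` in the forward pass scales that parameter's gradient by the same factor, which is what makes the offset "fast". Below is a minimal standalone sketch of the effect; the toy loss and variable names are illustrative, not from this patch.

```python
import torch
import torch.nn as nn

# Two zero-initialized offsets applied to the same input: one scaled
# by bias_scale in the forward pass (as FastOffsetLayer does), one not.
bias_scale = 100.0
fast = nn.Parameter(torch.zeros(4))
plain = nn.Parameter(torch.zeros(4))

x = torch.randn(8, 4)
target = torch.ones(8, 4)

loss = ((x + fast * bias_scale) - target).pow(2).mean() \
     + ((x + plain) - target).pow(2).mean()
loss.backward()

# The scaled parameter's gradient is bias_scale times larger, so under
# plain SGD the *effective* offset (fast * bias_scale) moves
# bias_scale**2 times further per step; under Adam, roughly bias_scale
# times, since Adam's step size is largely gradient-scale invariant.
print(fast.grad.norm() / plain.grad.norm())  # ~= bias_scale
```

This is presumably also why `final_linear` loses its bias (`bias=False`): the rapidly trained offset from the following `FastOffsetLayer` takes over that role.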