Some configuration changes; change how prob_boost works

Daniel Povey 2021-09-21 12:06:41 +08:00
parent 656de090bd
commit c4cc952265
2 changed files with 46 additions and 30 deletions
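The main structural change below replaces the small causal self-predictor encoder (self.self_predictor_encoder) with a learned bank of "fake" token embeddings that the existing reverse_decoder cross-attends to when predicting the discretized hidden symbols. A minimal sketch of that idea, assuming the reverse decoder is a plain nn.TransformerDecoder and using made-up sizes (the real decoder is built elsewhere in BidirectionalConformer):

```python
# Sketch only: illustrates the fake-token self-prediction path introduced in this
# commit.  d_model, nhead, num_layers and the use of nn.TransformerDecoder are
# assumptions for illustration, not the repository's actual construction code.
import torch
import torch.nn as nn

d_model, fake_token_seq_length = 256, 30
T, N = 50, 4  # (time, batch)

# Learned bank of "fake" token embeddings, replacing the old self_predictor_encoder.
fake_embedding = nn.Parameter(
    torch.randn(fake_token_seq_length, d_model) * (d_model ** -0.5))

reverse_decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=d_model, nhead=4), num_layers=2)

negative_embed_shifted = torch.randn(T, N, d_model)

# Causal mask over the target sequence (what generate_square_subsequent_mask() produces).
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# The scaled fake embeddings act as a fixed-length cross-attention memory that is
# shared across the batch.
token_memory = (fake_embedding * (d_model ** 0.5)).unsqueeze(1).expand(-1, N, -1)

hidden_predictor = reverse_decoder(tgt=negative_embed_shifted,
                                   memory=token_memory,
                                   tgt_mask=tgt_mask)
print(hidden_predictor.shape)  # torch.Size([50, 4, 256])
```

The resulting hidden_predictor is then fed to SampleAndPredict.compute_prob(), just as the old predictor's output was (see the fourth hunk below).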

View File

@@ -190,7 +190,7 @@ class BidirectionalConformer(nn.Module):
         num_decoder_layers: int = 6,
         num_reverse_encoder_layers: int = 4,
         num_reverse_decoder_layers: int = 4,
-        num_self_predictor_layers: int = 2,
+        fake_token_seq_length: int = 30,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         is_bpe: bool = False,
@@ -308,14 +308,8 @@ class BidirectionalConformer(nn.Module):
         # It just accepts the output of self.reverse_decoder as
         # the input to its prediction mechanism.
-        if num_self_predictor_layers > 0:
-            encoder_layer = SimpleCausalEncoderLayer(d_model,
-                                                     dropout=dropout)
-            final_linear = nn.Linear(d_model, d_model, bias=False)
-            self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
-                                                          for _ in range(num_self_predictor_layers)],
-                                                        final_linear,
-                                                        FastOffsetLayer(d_model))
+        if fake_token_seq_length > 0:
+            self.fake_embedding = torch.nn.Parameter(torch.randn(fake_token_seq_length, d_model) * (d_model ** -0.5))
 
         self.sample_and_predict = SampleAndPredict(
@@ -365,7 +359,6 @@ class BidirectionalConformer(nn.Module):
         shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
             (T, N, E) = positive_embed.shape
             positive_embed_shifted = torch.cat((torch.zeros(1, N, E), positive_embed[:-1,:,:]), dim=0)
         """
         (sampled, softmax, positive_embed, negative_embed) = self.sample_and_predict(memory)
@@ -514,13 +507,33 @@ class BidirectionalConformer(nn.Module):
         # layer uses left-padding only, making it causal, so the mask
         # is redundant (it wouldn't affect any of the
         # outputs we care about).
-        predictor = self.self_predictor_encoder(negative_embed_shifted)
-        prob = self.sample_and_predict.compute_prob(predictor,
-                                                    sampled, softmax,
-                                                    memory_key_padding_mask,
-                                                    reverse_grad=True)
-        return prob
+        (S, E) = self.fake_embedding.shape
+        (T, N, E2) = negative_embed_shifted.shape
+        assert E == E2
+        embedding_scale = E ** 0.5  # for better learning dynamics
+        token_memory = (self.fake_embedding * embedding_scale).unsqueeze(1).expand(S, N, E)
+        tokens_key_padding_mask = None
+
+        # the targets, here, are the hidden discrete symbols we are predicting
+        tgt_mask = generate_square_subsequent_mask(T, device=negative_embed_shifted.device)
+
+        hidden_predictor = self.reverse_decoder(
+            tgt=negative_embed_shifted,
+            memory=token_memory,
+            tgt_mask=tgt_mask,
+            memory_key_padding_mask=tokens_key_padding_mask)
+
+        total_prob = self.sample_and_predict.compute_prob(
+            hidden_predictor,
+            sampled,
+            softmax,
+            memory_key_padding_mask,
+            reverse_grad=True)
+        # TODO: consider using a label-smoothed loss.
+        return total_prob
 
     def reverse_decoder_forward(
@@ -748,8 +761,9 @@ class SampleAndPredict(nn.Module):
         tot_classes: int,
         num_groups: int,
         interp_prob: float = 1.0,
-        straight_through_scale: float = 0.0,
-        min_prob_ratio: float = 0.1,
+        straight_through_scale: float = 1.0,
+        min_prob_ratio: float = 0.3,
+        max_prob_ratio: float = 3.0,
     ):
         super(SampleAndPredict, self).__init__()
@@ -759,11 +773,12 @@ class SampleAndPredict(nn.Module):
         self.interp_prob = interp_prob
         self.straight_through_scale = straight_through_scale
         self.min_prob_ratio = min_prob_ratio
+        self.max_prob_ratio = max_prob_ratio
         self.tot_classes = tot_classes
         self.classes_per_group = tot_classes // num_groups
 
-        # prob_boost relates to the min_prob_ratio setting.  It's not configurable for now.
-        self.prob_boost = 1.0e-03
+        # prob_boost relates to the min_prob_ratio and max_prob_ratio settings.  It's not configurable for now.
+        self.prob_boost = 1.0e-02
 
         # class_probs is a rolling mean of the output of the sampling operation.
         # When any element of it gets below self.min_prob_ratio / self.classes_per_group,
@@ -807,6 +822,7 @@ class SampleAndPredict(nn.Module):
         self._reset_parameters()
 
     def _reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.linear1.weight, gain=5)
         if hasattr(self, 'pred_cross'):
             torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))
@@ -839,9 +855,9 @@ class SampleAndPredict(nn.Module):
         something whose gradient is already (going to be) reversed:
         specifically, the self-prediction network.
         """
-        x = self.linear1(x) * 3
-        if self.min_prob_ratio > 0.0:
+        x = self.linear1(x)
+        if self.min_prob_ratio != 0.0 or self.max_prob_ratio != 0.0:
             x = x + self.class_offsets
 
         (S, N, tot_classes) = x.shape
@@ -852,7 +868,7 @@ class SampleAndPredict(nn.Module):
         # 'flow_sample'.
         softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None
 
-        if random.random() < 0.05:
+        if random.random() < 0.01:
             # Some info that's useful for debug.
             softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
             logsoftmax_temp = (softmax_temp + 1.0e-20).log()
@@ -877,12 +893,14 @@ class SampleAndPredict(nn.Module):
             sampled = x
 
-        if self.training and self.min_prob_ratio > 0.0:
+        if self.training and (self.min_prob_ratio != 0.0 or self.max_prob_ratio != 0.0):
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
             prob_floor = self.min_prob_ratio / self.classes_per_group
-            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
+            prob_ceil = self.max_prob_ratio / self.classes_per_group
+            self.class_offsets += (((self.class_probs < prob_floor) * self.prob_boost) -
+                                   ((self.class_probs > prob_ceil)) * self.prob_boost)
 
         positive_embed = self.post_layer_norm(self.linear2(sampled))
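The prob_boost change above now adjusts the class offsets in both directions: an offset is boosted when a class's rolling usage probability drops below min_prob_ratio / classes_per_group and reduced when it rises above max_prob_ratio / classes_per_group. A standalone sketch of just that update rule, with invented sizes (in the real module class_probs is an exponential moving average of the sampler's output, updated with class_probs_decay each step):

```python
import torch

classes_per_group = 64
min_prob_ratio, max_prob_ratio = 0.3, 3.0
prob_boost = 1.0e-02

# Stand-in for the rolling mean of how often each class gets sampled.
class_probs = torch.rand(classes_per_group)
class_probs = class_probs / class_probs.sum()

class_offsets = torch.zeros(classes_per_group)

prob_floor = min_prob_ratio / classes_per_group  # under-used threshold
prob_ceil = max_prob_ratio / classes_per_group   # over-used threshold

# Under-used classes get their logit offset nudged up, over-used classes nudged down;
# before this commit only the floor-based boost existed (and prob_boost was 1.0e-03).
class_offsets += ((class_probs < prob_floor).float() * prob_boost -
                  (class_probs > prob_ceil).float() * prob_boost)
print(class_offsets)
```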

View File

@@ -155,7 +155,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_1"),
+            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_2"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,  # can't be changed
@@ -171,8 +171,8 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_scale": 0.7,
-            "reverse_att_scale": 0.01,  # ctc_scale == 1.0 - att_scale - reverse_att_scale
+            "att_scale": 0.6,
+            "reverse_att_scale": 0.1,  # ctc_scale == 1.0 - att_scale - reverse_att_scale
             "attention_dim": 512,
             "nhead": 8,
             "num_trunk_encoder_layers": 12,
@@ -180,7 +180,6 @@ def get_params() -> AttributeDict:
             "num_decoder_layers": 6,
             "num_reverse_encoder_layers": 4,
             "num_reverse_decoder_layers": 4,
-            "num_self_predictor_layers": 2,
             "discretization_tot_classes": 512,
             "discretization_num_groups": 8,
             "is_bpe": True,
@@ -679,7 +678,6 @@ def run(rank, world_size, args):
         num_decoder_layers=params.num_decoder_layers,
         num_reverse_encoder_layers=params.num_reverse_encoder_layers,
         num_reverse_decoder_layers=params.num_reverse_decoder_layers,
-        num_self_predictor_layers=params.num_self_predictor_layers,
         subsampling_factor=params.subsampling_factor,
         is_bpe=params.is_bpe,
         discretization_tot_classes=params.discretization_tot_classes,
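For reference, the comment in get_params() still has to hold (ctc_scale == 1.0 - att_scale - reverse_att_scale), so the new values mainly move weight from the attention loss to the reverse-attention loss while leaving the implied CTC weight almost unchanged:

```python
# Implied CTC weight before and after this commit (arithmetic only).
for att_scale, reverse_att_scale in [(0.7, 0.01), (0.6, 0.1)]:  # old -> new
    ctc_scale = 1.0 - att_scale - reverse_att_scale
    print(f"att={att_scale}  reverse_att={reverse_att_scale}  ctc_scale={ctc_scale:.2f}")
# att=0.7  reverse_att=0.01  ctc_scale=0.29
# att=0.6  reverse_att=0.1  ctc_scale=0.30
```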