Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-19 06:53:10 +00:00)

Commit c4cc952265, parent 656de090bd

Some configuration changes; change how prob_boost works.
@@ -190,7 +190,7 @@ class BidirectionalConformer(nn.Module):
        num_decoder_layers: int = 6,
        num_reverse_encoder_layers: int = 4,
        num_reverse_decoder_layers: int = 4,
-        num_self_predictor_layers: int = 2,
+        fake_token_seq_length: int = 30,
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
        is_bpe: bool = False,
@@ -308,14 +308,8 @@ class BidirectionalConformer(nn.Module):
        # It just accepts the output of self.reverse_decoder as
        # the input to its prediction mechanism.

-        if num_self_predictor_layers > 0:
-            encoder_layer = SimpleCausalEncoderLayer(d_model,
-                                                     dropout=dropout)
-            final_linear = nn.Linear(d_model, d_model, bias=False)
-            self.self_predictor_encoder = nn.Sequential(*[copy.deepcopy(encoder_layer)
-                                                          for _ in range(num_self_predictor_layers)],
-                                                        final_linear,
-                                                        FastOffsetLayer(d_model))
+        if fake_token_seq_length > 0:
+            self.fake_embedding = torch.nn.Parameter(torch.randn(fake_token_seq_length, d_model) * (d_model ** -0.5))


        self.sample_and_predict = SampleAndPredict(
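The fake_embedding parameter added here is stored pre-scaled by d_model ** -0.5 and multiplied back by d_model ** 0.5 where it is consumed (the later hunk calls that scale "for better learning dynamics"). A toy check of the cancellation, with an illustrative d_model that is not taken from this diff:

import torch

d_model, fake_token_seq_length = 256, 30   # d_model is assumed; 30 is the new default
fake_embedding = torch.nn.Parameter(
    torch.randn(fake_token_seq_length, d_model) * (d_model ** -0.5))

# The stored parameter stays small (std ~ d_model ** -0.5), while the values
# actually handed to the decoder are scaled back to roughly unit std.
scaled = fake_embedding * (d_model ** 0.5)
print(fake_embedding.std().item(), scaled.std().item())   # ~0.06 vs ~1.0
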
@@ -365,7 +359,6 @@ class BidirectionalConformer(nn.Module):
          shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
             (T, N, E) = positive_embed.shape
             positive_embed_shifted = torch.cat((torch.zeros(1, N, E), positive_embed[:-1,:,:]), dim=0)

        """
        (sampled, softmax, positive_embed, negative_embed) = self.sample_and_predict(memory)

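The shift described in this docstring can be checked in isolation; the sketch below just replays the two quoted lines on a toy tensor (the sizes are arbitrary):

import torch

T, N, E = 5, 2, 4
positive_embed = torch.randn(T, N, E)
positive_embed_shifted = torch.cat(
    (torch.zeros(1, N, E), positive_embed[:-1, :, :]), dim=0)

# positive_embed_shifted[t] == positive_embed[t-1] for t >= 1; all zeros at t == 0.
assert torch.equal(positive_embed_shifted[0], torch.zeros(N, E))
for t in range(1, T):
    assert torch.equal(positive_embed_shifted[t], positive_embed[t - 1])
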
@@ -514,13 +507,33 @@ class BidirectionalConformer(nn.Module):
        # layer uses left-padding only, making it causal, so the mask
        # is redundant (it wouldn't affect any of the
        # outputs we care about).
-        predictor = self.self_predictor_encoder(negative_embed_shifted)

-        prob = self.sample_and_predict.compute_prob(predictor,
-                                                    sampled, softmax,
-                                                    memory_key_padding_mask,
-                                                    reverse_grad=True)
-        return prob
+        (S, E) = self.fake_embedding.shape
+        (T, N, E2) = negative_embed_shifted.shape
+        assert E == E2
+        embedding_scale = E ** 0.5  # for better learning dynamics
+        token_memory = (self.fake_embedding * embedding_scale).unsqueeze(1).expand(S, N, E)
+
+        tokens_key_padding_mask = None
+
+        # the targets, here, are the hidden discrete symbols we are predicting
+        tgt_mask = generate_square_subsequent_mask(T, device=negative_embed_shifted.device)
+
+        hidden_predictor = self.reverse_decoder(
+            tgt=negative_embed_shifted,
+            memory=token_memory,
+            tgt_mask=tgt_mask,
+            memory_key_padding_mask=tokens_key_padding_mask)
+
+        total_prob = self.sample_and_predict.compute_prob(
+            hidden_predictor,
+            sampled,
+            softmax,
+            memory_key_padding_mask,
+            reverse_grad=True)
+
+        # TODO: consider using a label-smoothed loss.
+        return total_prob


    def reverse_decoder_forward(
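generate_square_subsequent_mask builds the usual causal mask for a Transformer decoder, and token_memory is just the fake-token embedding broadcast across the batch. A plain-PyTorch sketch of both (the helper in this file may differ in detail, and the sizes below are made up):

import torch

def square_subsequent_mask(sz, device=None):
    # 0.0 on and below the diagonal, -inf strictly above it, so position t can
    # only attend to positions <= t.
    mask = torch.full((sz, sz), float("-inf"), device=device)
    return torch.triu(mask, diagonal=1)

S, N, E = 30, 4, 256      # fake_token_seq_length, batch, d_model (illustrative)
fake_embedding = torch.randn(S, E) * (E ** -0.5)
token_memory = (fake_embedding * (E ** 0.5)).unsqueeze(1).expand(S, N, E)
assert token_memory.shape == (S, N, E)
assert square_subsequent_mask(5)[0, 1] == float("-inf")
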
@@ -748,8 +761,9 @@ class SampleAndPredict(nn.Module):
                 tot_classes: int,
                 num_groups: int,
                 interp_prob: float = 1.0,
-                 straight_through_scale: float = 0.0,
-                 min_prob_ratio: float = 0.1,
+                 straight_through_scale: float = 1.0,
+                 min_prob_ratio: float = 0.3,
+                 max_prob_ratio: float = 3.0,
    ):
        super(SampleAndPredict, self).__init__()

@@ -759,11 +773,12 @@ class SampleAndPredict(nn.Module):
        self.interp_prob = interp_prob
        self.straight_through_scale = straight_through_scale
        self.min_prob_ratio = min_prob_ratio
+        self.max_prob_ratio = max_prob_ratio
        self.tot_classes = tot_classes
        self.classes_per_group = tot_classes // num_groups

-        # prob_boost relates to the min_prob_ratio setting. It's not configurable for now.
-        self.prob_boost = 1.0e-03
+        # prob_boost relates to the min_prob_ratio and max_prob_ratio settings. It's not configurable for now.
+        self.prob_boost = 1.0e-02

        # class_probs is a rolling mean of the output of the sampling operation.
        # When any element of it gets below self.min_prob_ratio / self.classes_per_group,
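Read together with the forward-pass hunk further down, the new behaviour is roughly: keep a rolling mean of how often each class is chosen, add prob_boost to the logit offset of any class whose running probability falls below min_prob_ratio / classes_per_group, and subtract it for any class above max_prob_ratio / classes_per_group. A functional sketch of that update (the decay constant and the flat shapes are assumptions; the real module keeps these values as buffers):

import torch

def update_offsets(class_probs, class_offsets, new_mean_probs,
                   classes_per_group=64, min_prob_ratio=0.3, max_prob_ratio=3.0,
                   prob_boost=1.0e-02, class_probs_decay=0.9):
    # Rolling mean of the per-class sampling frequency.
    class_probs = (class_probs * class_probs_decay +
                   new_mean_probs * (1.0 - class_probs_decay))
    prob_floor = min_prob_ratio / classes_per_group
    prob_ceil = max_prob_ratio / classes_per_group
    # Boost under-used classes, penalize over-used ones.
    class_offsets = class_offsets + ((class_probs < prob_floor).float() -
                                     (class_probs > prob_ceil).float()) * prob_boost
    return class_probs, class_offsets
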
@@ -807,6 +822,7 @@ class SampleAndPredict(nn.Module):
        self._reset_parameters()

    def _reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.linear1.weight, gain=5)
        if hasattr(self, 'pred_cross'):
            torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))

@@ -839,9 +855,9 @@ class SampleAndPredict(nn.Module):
          something whose gradient is already (going to be) reversed:
          specifically, the self-prediction network.
        """
-        x = self.linear1(x) * 3
+        x = self.linear1(x)

-        if self.min_prob_ratio > 0.0:
+        if self.min_prob_ratio != 0.0 or self.max_prob_ratio != 0.0:
            x = x + self.class_offsets

        (S, N, tot_classes) = x.shape
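The reverse_grad=True arguments in the compute_prob calls, and the mention here of the self-prediction network's gradient being reversed, point to the standard gradient-reversal trick. The repo has its own implementation; a generic version of the pattern looks like this:

import torch

class ReverseGrad(torch.autograd.Function):
    # Identity in the forward pass, negated gradient in the backward pass, so
    # whatever feeds this op is trained to make the downstream objective worse.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

x = torch.randn(3, requires_grad=True)
ReverseGrad.apply(x).sum().backward()
assert torch.all(x.grad == -1.0)
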
@@ -852,7 +868,7 @@ class SampleAndPredict(nn.Module):
        # 'flow_sample'.
        softmax = x.softmax(dim=3).reshape(S, N, tot_classes) if need_softmax else None

-        if random.random() < 0.05:
+        if random.random() < 0.01:
            # Some info that's useful for debug.
            softmax_temp = softmax.reshape(S, N, self.num_groups, self.classes_per_group)
            logsoftmax_temp = (softmax_temp + 1.0e-20).log()
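The softmax above is taken separately within each of num_groups groups of classes_per_group classes (hence dim=3 after the reshape). A shape-only sketch, with the group sizes taken from the training config in this commit and the other sizes made up:

import torch

S, N = 10, 4                       # arbitrary sequence length and batch size
num_groups, tot_classes = 8, 512   # discretization settings from get_params()
classes_per_group = tot_classes // num_groups

x = torch.randn(S, N, num_groups, classes_per_group)
softmax = x.softmax(dim=3).reshape(S, N, tot_classes)

# Each group of classes_per_group entries sums to 1, so a whole frame sums to
# num_groups rather than 1.
per_group = softmax.reshape(S, N, num_groups, classes_per_group)
assert torch.allclose(per_group.sum(dim=3), torch.ones(S, N, num_groups))

# The debug branch takes logs with a small floor to avoid log(0), e.g. to
# estimate per-group entropies:
logsoftmax = (per_group + 1.0e-20).log()
mean_entropy = -(per_group * logsoftmax).sum(dim=3).mean()
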
@@ -877,12 +893,14 @@ class SampleAndPredict(nn.Module):

        sampled = x

-        if self.training and self.min_prob_ratio > 0.0:
+        if self.training and (self.min_prob_ratio != 0.0 or self.max_prob_ratio != 0.0):
            mean_class_probs = torch.mean(x.detach(), dim=(0,1))
            self.class_probs = (self.class_probs * self.class_probs_decay +
                                mean_class_probs * (1.0 - self.class_probs_decay))
            prob_floor = self.min_prob_ratio / self.classes_per_group
-            self.class_offsets += (self.class_probs < prob_floor) * self.prob_boost
+            prob_ceil = self.max_prob_ratio / self.classes_per_group
+            self.class_offsets += (((self.class_probs < prob_floor) * self.prob_boost) -
+                                   ((self.class_probs > prob_ceil)) * self.prob_boost)


        positive_embed = self.post_layer_norm(self.linear2(sampled))

@@ -155,7 +155,7 @@ def get_params() -> AttributeDict:
    """
    params = AttributeDict(
        {
-            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_1"),
+            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_2"),
            "lang_dir": Path("data/lang_bpe"),
            "feature_dim": 80,
            "subsampling_factor": 4, # can't be changed
@@ -171,8 +171,8 @@ def get_params() -> AttributeDict:
            "reduction": "sum",
            "use_double_scores": True,
            "accum_grad": 1,
-            "att_scale": 0.7,
-            "reverse_att_scale": 0.01, # ctc_scale == 1.0 - att_scale - reverse_att_scale
+            "att_scale": 0.6,
+            "reverse_att_scale": 0.1, # ctc_scale == 1.0 - att_scale - reverse_att_scale
            "attention_dim": 512,
            "nhead": 8,
            "num_trunk_encoder_layers": 12,
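Per the comment on reverse_att_scale, the CTC weight is whatever is left over; with the new values it becomes 0.3 (it was 1.0 - 0.7 - 0.01 = 0.29 before this commit):

att_scale, reverse_att_scale = 0.6, 0.1
ctc_scale = 1.0 - att_scale - reverse_att_scale   # 0.3
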
@@ -180,7 +180,6 @@ def get_params() -> AttributeDict:
            "num_decoder_layers": 6,
            "num_reverse_encoder_layers": 4,
            "num_reverse_decoder_layers": 4,
-            "num_self_predictor_layers": 2,
            "discretization_tot_classes": 512,
            "discretization_num_groups": 8,
            "is_bpe": True,
@@ -679,7 +678,6 @@ def run(rank, world_size, args):
        num_decoder_layers=params.num_decoder_layers,
        num_reverse_encoder_layers=params.num_reverse_encoder_layers,
        num_reverse_decoder_layers=params.num_reverse_decoder_layers,
-        num_self_predictor_layers=params.num_self_predictor_layers,
        subsampling_factor=params.subsampling_factor,
        is_bpe=params.is_bpe,
        discretization_tot_classes=params.discretization_tot_classes,