Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Introduce factor of 2 to more strongly penalize discarded weights.
parent 824d7b4492 · commit 5fc0cce553
@@ -805,17 +805,13 @@ class LearnedDownsamplingModule(nn.Module):
       downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
          fundamental reason why this has to be an integer, but we make it so
          anyway.
-      weight_scale: constant scaling factor on the weights, introduced to make fp16 training
-         more stable by reducing gradient magnitudes.
     """
     def __init__(self,
                  embed_dim: int,
-                 downsampling_factor: int,
-                 weight_scale: float = 1.0):
+                 downsampling_factor: int):
         super().__init__()

-        self.weight_scale = weight_scale
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
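The hunk above removes the weight_scale option entirely; the scoring layer itself is untouched. For context, to_scores is just a bias-free linear map from each frame embedding to a single scalar score. A minimal standalone sketch of that step, with toy shapes that are not taken from icefall:

    # Standalone sketch, not the icefall module: how a Linear(embed_dim, 1, bias=False)
    # layer like self.to_scores turns frame embeddings into one score per frame.
    import torch
    import torch.nn as nn

    embed_dim, batch, seq_len = 8, 2, 10          # toy sizes, not from the diff
    to_scores = nn.Linear(embed_dim, 1, bias=False)
    x = torch.randn(batch, seq_len, embed_dim)    # (batch, seq_len, embed_dim) frame embeddings
    scores = to_scores(x).squeeze(-1)             # (batch, seq_len): one scalar score per frame
    print(scores.shape)                           # torch.Size([2, 10])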
@@ -859,7 +855,7 @@ class LearnedDownsamplingModule(nn.Module):
         sscores, indexes = scores.sort(dim=-1, descending=True)


-        weights = sscores.clamp(min=0.0, max=self.weight_scale)
+        weights = sscores.clamp(min=0.0, max=1.0)
         weights = self.copy_weights1(weights)

         if self.training:
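With weight_scale gone, the keep-weights are now always clamped to [0, 1]. A toy run of the sort-then-clamp step shown above, assuming nothing beyond what the two lines themselves do:

    # Toy illustration of the sort-then-clamp step in the hunk above.
    import torch

    scores = torch.tensor([[0.7, -0.3, 1.5, 0.1]])
    sscores, indexes = scores.sort(dim=-1, descending=True)
    weights = sscores.clamp(min=0.0, max=1.0)
    print(sscores)  # tensor([[ 1.5000,  0.7000,  0.1000, -0.3000]])
    print(weights)  # tensor([[1.0000, 0.7000, 0.1000, 0.0000]])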
@@ -879,9 +875,12 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")

-            #weights_discarded = weights_discarded.flip(dims=(1,))

-            weights = (weights[:, :seq_len_reduced] - weights_discarded)
+            # we were getting too many discarded weights before introducing this factor, which was
+            # hurting test-mode performance by creating a mismatch.
+            discarded_weights_factor = 2.0
+
+            weights = (weights[:, :seq_len_reduced] - (weights_discarded * discarded_weights_factor)).clamp(min=0.0, max=1.0)
         else:
             # test mode.  because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
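This is the change the commit title describes. The hunk does not show how weights_discarded is constructed, so the sketch below simply assumes it is a tensor aligned with the kept positions that carries the weight mass of the frames being dropped; the point is only that doubling it (and clamping) pushes more kept weights to exactly zero than the old plain subtraction did, reducing the train/test mismatch mentioned in the new comment:

    # Toy numbers only; weights_discarded here is assumed, not taken from icefall.
    import torch

    kept = torch.tensor([[0.9, 0.6, 0.4]])               # weights[:, :seq_len_reduced]
    weights_discarded = torch.tensor([[0.1, 0.3, 0.5]])  # assumed aligned discarded mass

    old = kept - weights_discarded                                   # before this commit
    new = (kept - 2.0 * weights_discarded).clamp(min=0.0, max=1.0)   # after this commit
    print(old)  # tensor([[ 0.8000,  0.3000, -0.1000]])
    print(new)  # tensor([[0.7000, 0.0000, 0.0000]])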
@@ -991,9 +990,6 @@ class LearnedDownsamplingModule(nn.Module):
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + (weights + eps).log().unsqueeze(1)

-        if self.weight_scale != 1.0:
-            attn_offset = attn_offset - math.log(self.weight_scale)
-
         return attn_offset
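With weight_scale removed, the compensating -log(weight_scale) term is no longer needed. The remaining (weights + eps).log() term acts as a soft attention mask: frames kept with weight near 1 get an offset near 0, while frames whose weight was driven to 0 get a very negative offset and are effectively ignored after the softmax; the unsqueeze(1) broadcasts the cost over query positions so it attaches to the source (key) positions. A toy illustration, where the value of eps is an assumption (it is defined elsewhere in the module, not in this hunk):

    # Toy illustration of the log-offset; the eps value is an assumption.
    import torch

    eps = 1.0e-10
    weights = torch.tensor([[1.0, 0.5, 0.0]])
    offset = (weights + eps).log()
    print(offset)  # tensor([[  0.0000,  -0.6931, -23.0259]])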