Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Implement weight_scale, set weight_scale=10

commit d2c198c072, parent f6c7392430
@@ -805,13 +805,17 @@ class LearnedDownsamplingModule(nn.Module):
       downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
          fundamental reason why this has to be an integer, but we make it so
          anyway.
+      weight_scale: constant scaling factor on the weights, introduced to make fp16 training
+         more stable by reducing gradient magnitudes.
     """
     def __init__(self,
                  embed_dim: int,
-                 downsampling_factor: int):
+                 downsampling_factor: int,
+                 weight_scale: float = 10.0):

         super().__init__()

+        self.weight_scale = weight_scale
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
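Why scaling the weights up reduces gradient magnitudes, in a minimal sketch (my illustration, not part of the commit): the attention offset further down adds log(weights), and the gradient of log(w) with respect to w is 1/w, so very small weights produce very large gradients that can overflow float16 (whose largest finite value is about 65504). Multiplying the weights by a constant shrinks that gradient by the same constant:

import torch

# grad of log(w) w.r.t. w is 1/w; scaling w by 10 scales the grad by 1/10.
for scale in (1.0, 10.0):
    w = torch.tensor([1e-5 * scale], requires_grad=True)
    w.log().sum().backward()
    print(scale, w.grad.item())  # 1e5 at scale 1.0 (overflows fp16), 1e4 at scale 10.0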
@@ -855,7 +859,7 @@ class LearnedDownsamplingModule(nn.Module):
         sscores, indexes = scores.sort(dim=-1, descending=True)


-        weights = sscores.clamp(min=0.0, max=1.0)
+        weights = sscores.clamp(min=0.0, max=self.weight_scale)
         weights = self.copy_weights1(weights)

         if self.training:
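A hypothetical before/after illustration of the changed clamp (example values are made up): with weight_scale = 10.0, sorted scores keep their magnitude up to 10 instead of saturating at 1.

import torch

sscores = torch.tensor([12.0, 4.0, 0.5, -3.0])
weight_scale = 10.0
old = sscores.clamp(min=0.0, max=1.0)           # tensor([1.0000, 1.0000, 0.5000, 0.0000])
new = sscores.clamp(min=0.0, max=weight_scale)  # tensor([10.0000, 4.0000, 0.5000, 0.0000])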
@@ -987,6 +991,9 @@ class LearnedDownsamplingModule(nn.Module):
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)

+        if self.weight_scale != 1.0:
+            attn_offset = attn_offset - math.log(self.weight_scale)
+
         return attn_offset
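The subtraction restores the old convention in log space: since log(w) - log(weight_scale) = log(w / weight_scale), the effective offset is the log of a weight back in [0, 1], i.e. always <= 0, just as when the weights themselves were clamped to [0, 1]. A short sketch (eps and the example values are assumptions):

import math
import torch

weight_scale = 10.0
eps = 1.0e-10
weights = torch.tensor([10.0, 4.0, 0.5])
offset = weights.clamp(min=eps).log() - math.log(weight_scale)
# same as torch.log(weights / weight_scale): tensor([ 0.0000, -0.9163, -2.9957])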