Mirror of https://github.com/k2-fsa/icefall.git
Remove the use of random_clamp in conformer.py.
This commit is contained in:
parent dccff6b893
commit c5cb52fed1
@@ -940,7 +940,6 @@ class RelPositionMultiheadAttention(nn.Module):
             training=self.training,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
-            attn_weights_max=5.0 if self.training else None,
         )
         return x, weights

@@ -959,7 +958,6 @@ class RelPositionMultiheadAttention(nn.Module):
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
-        attn_weights_max: Optional[float] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -1111,16 +1109,6 @@ class RelPositionMultiheadAttention(nn.Module):
         # caution: they are really scores at this point.
         attn_output_weights = torch.matmul(q, k) + pos_weights

-        # The following is a soft way of encouraging the attention scores to not be too large;
-        # in training time, once they get outside a certain range, -5.0..5.0 currently, we
-        # randomly either leave them as-is or truncate them to that range.
-        if attn_weights_max is not None:
-            attn_output_weights = random_clamp(attn_output_weights,
-                                               min=-attn_weights_max,
-                                               max=attn_weights_max,
-                                               prob=0.5,
-                                               reflect=0.1)
-
         if training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be too large.
             # It incurs a penalty if any of them has an absolute value greater than 50.0.
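For readers without random_clamp in front of them: the deleted comment describes its effect at this call site. Below is a minimal sketch of that behavior, assuming an independent per-element coin flip. The name random_clamp_sketch is ours, and the real random_clamp in icefall is a custom autograd function whose reflect=0.1 argument also shapes the backward pass in a way this diff does not show, so it is omitted here. The min/max parameter names mirror the keyword arguments at the deleted call site.

import torch
from torch import Tensor


def random_clamp_sketch(x: Tensor, min: float, max: float,
                        prob: float = 0.5) -> Tensor:
    # For each element, with probability `prob`, truncate it to
    # [min, max]; otherwise pass it through unchanged.  This matches
    # the deleted comment: scores outside -5.0..5.0 are "randomly
    # either left as-is or truncated to that range".
    clamped = x.clamp(min=min, max=max)
    use_clamped = torch.rand_like(x) < prob
    return torch.where(use_clamped, clamped, x)


# The deleted call site corresponds roughly to:
# attn_output_weights = random_clamp_sketch(
#     attn_output_weights, min=-5.0, max=5.0, prob=0.5)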
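The surviving branch ("a harder way") is cut off by the page before its body, but its comment states the idea: incur a penalty whenever any score's absolute value exceeds 50.0. Below is a minimal sketch of one way to realize that, as an explicit auxiliary loss term; the name abs_value_penalty_sketch, the linear penalty shape, and the scale factor are our assumptions, not icefall's code.

import torch
from torch import Tensor


def abs_value_penalty_sketch(scores: Tensor, limit: float = 50.0,
                             scale: float = 1.0e-04) -> Tensor:
    # Zero while every score stays inside [-limit, limit]; grows
    # linearly once any score escapes that range.  Added to the
    # training loss, its gradient pushes oversized scores back
    # toward the limit.
    excess = (scores.abs() - limit).clamp(min=0.0)
    return scale * excess.sum()

Gating this with training and random.random() < 0.1, as the kept code does, applies the constraint on roughly 10% of training batches, enough to keep scores bounded without touching every step.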