Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Merge branch 'scaled_adam_exp387' into scaled_adam_exp390
Commit 380f773069
@@ -728,7 +728,7 @@ class LimitParamValue(torch.autograd.Function):
         x, = ctx.saved_tensors
         # where x < ctx.min, ensure all grads are negative (this will tend to make
         # x more positive).
-        x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+        x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
         # where x > ctx.max, ensure all grads are positive (this will tend to make
         # x more negative).
         x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
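The change above makes the first sign-flip an out-of-place multiply instead of the in-place `*=`, so the incoming gradient tensor is not modified in place. For context, a minimal sketch of how such a gradient-limiting autograd Function fits together; only the backward lines appear in this hunk, so the forward/return scaffolding below is an assumption, not the exact icefall implementation:

import torch
from torch import Tensor


class LimitParamValueSketch(torch.autograd.Function):
    """Identity in the forward pass; in backward, flips the sign of any gradient
    component that would push x further outside [min, max]."""

    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.min = min
        ctx.max = max
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        x, = ctx.saved_tensors
        # where x < ctx.min, make all grads negative (this tends to push x back up);
        # the multiply is out-of-place so x_grad itself is never mutated.
        x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
        # where x > ctx.max, make all grads positive (this tends to push x back down).
        x_grad = x_grad * torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
        return x_grad, None, None

Usage would look like `x = LimitParamValueSketch.apply(x, min_value, max_value)`: the forward value is unchanged, only the gradient direction is constrained.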
@@ -553,7 +553,7 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_prob: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -884,8 +884,22 @@ class SimpleCombiner(torch.nn.Module):


-class RelPositionalEncoding(torch.nn.Module):
-    """Relative positional encoding module.
+class CompactRelPositionalEncoding(torch.nn.Module):
+    """
+    Relative positional encoding module.  This version is "compact", meaning it is able to encode
+    the important information about the relative position in a relatively small number of dimensions.
+    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+    make very little difference to the embedding.  Such differences were potentially important
+    when encoding absolute position, but not when encoding relative position, because there
+    is now no need to compare two large offsets with each other.
+
+    Our embedding works by projecting the interval [-infinity, infinity] to a finite interval
+    using the atan() function, before doing the Fourier transform of that fixed interval.  The
+    atan() function on its own would compress the "long tails" into too small a range, making it
+    hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual
+    offsets as long as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
+
     Args:
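A small numeric illustration of the scheme the new docstring describes. The exact compression function, length scale, and frequency choices used by CompactRelPositionalEncoding are not shown in this diff, so everything below is an assumed stand-in rather than the icefall implementation:

import math
import torch


def compact_rel_pos_sketch(offsets: torch.Tensor, embed_dim: int) -> torch.Tensor:
    # Illustrative only: log-compress large offsets, squash them into a bounded
    # interval with atan(), then take sines/cosines of the bounded value.
    assert embed_dim % 2 == 0
    length_scale = math.sqrt(embed_dim)        # assumed scale; offsets up to ~sqrt(embed_dim) stay distinct
    x = offsets.to(torch.float32) / length_scale
    x = torch.sign(x) * torch.log1p(x.abs())   # sign-preserving logarithmic compression of the tails
    x = torch.atan(x)                          # map (-inf, inf) to the bounded interval (-pi/2, pi/2)
    freqs = torch.arange(1, embed_dim // 2 + 1, dtype=torch.float32)
    angles = x.unsqueeze(-1) * freqs           # (num_offsets, embed_dim // 2)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (num_offsets, embed_dim)

With this kind of mapping, offsets 1000 and 1001 land at nearly the same point of the bounded interval, while offsets near zero remain clearly separated.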
@@ -898,8 +912,8 @@ class RelPositionalEncoding(torch.nn.Module):
         dropout_rate: float,
         max_len: int = 1000
     ) -> None:
-        """Construct a PositionalEncoding object."""
-        super(RelPositionalEncoding, self).__init__()
+        """Construct a CompactRelPositionalEncoding object."""
+        super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -1199,7 +1213,7 @@ class SelfAttention(nn.Module):
     Args:
           embed_dim: the input and output embedding dimension
           num_heads: the number of attention heads
-          value_dim: the value dimension per head
+          value_head_dim: the value dimension per head
    """
    def __init__(
        self,
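The renamed argument makes explicit that the value dimension is specified per attention head. A minimal sketch of the shape convention this implies; the module below and its layer names are illustrative assumptions, not the icefall SelfAttention body (which is not part of this hunk):

import torch
import torch.nn as nn


class ValueProjectionSketch(nn.Module):
    """Applies precomputed attention weights to values whose size is given per head."""

    def __init__(self, embed_dim: int, num_heads: int, value_head_dim: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.value_head_dim = value_head_dim
        # total value dimension is num_heads * value_head_dim
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim)
        self.out_proj = nn.Linear(num_heads * value_head_dim, embed_dim)

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, embed_dim); attn_weights: (batch * num_heads, seq_len, seq_len)
        seq_len, batch, _ = x.shape
        v = self.in_proj(x)                                   # (seq_len, batch, num_heads * value_head_dim)
        v = v.reshape(seq_len, batch * self.num_heads, self.value_head_dim).transpose(0, 1)
        out = torch.bmm(attn_weights, v)                      # (batch * num_heads, seq_len, value_head_dim)
        out = out.transpose(0, 1).reshape(seq_len, batch, self.num_heads * self.value_head_dim)
        return self.out_proj(out)                             # (seq_len, batch, embed_dim)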