mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Cosmetic improvements
This commit is contained in:
parent
46bd93b792
commit
d1df919547
@ -553,7 +553,7 @@ class ZipformerEncoder(nn.Module):
|
||||
final_layerdrop_prob: float = 0.05,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
|
||||
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
@ -884,8 +884,22 @@ class SimpleCombiner(torch.nn.Module):
|
||||
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
"""
|
||||
Relative positional encoding module. This version is "compact" meaning it is able to encode
|
||||
the important information about the relative position in a relatively small number of dimensions.
|
||||
The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
|
||||
make very little difference to the embedding. Such differences were potentially important
|
||||
when encoding absolute position, but not important when encoding relative position because there
|
||||
is now no need to compare two large offsets with each other.
|
||||
|
||||
Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval
|
||||
using the atan() function, before doing the fourier transform of that fixed interval. The
|
||||
atan() function would compress the "long tails" too small,
|
||||
making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
|
||||
function to compress large offsets to a smaller range before applying atan().
|
||||
Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long
|
||||
as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)
|
||||
|
||||
|
||||
Args:
|
||||
@ -898,8 +912,8 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
dropout_rate: float,
|
||||
max_len: int = 1000
|
||||
) -> None:
|
||||
"""Construct a PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
"""Construct a CompactRelPositionalEncoding object."""
|
||||
super(CompactRelPositionalEncoding, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
assert embed_dim % 2 == 0
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user