Refactoring, and change length_factor from 2.0 to 1.5.

Daniel Povey 2022-11-20 16:34:51 +08:00
parent a52ec3da28
commit cdfbbdded2


@@ -544,7 +544,8 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
+                                                        length_factor=1.5)

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -897,11 +898,14 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+           less weight to small differences of offset near the origin.
     """
     def __init__(
         self, embed_dim: int,
         dropout_rate: float,
-        max_len: int = 1000
+        max_len: int = 1000,
+        length_factor: float = 1.0,
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
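
(Aside, not part of the commit: a rough sketch of what the new length_factor argument does, with embed_dim = 192 assumed purely for illustration and the log-compression step skipped. A larger length_factor enlarges length_scale, so nearby offsets map to smaller, more similar angles, i.e. small differences of offset near the origin get less weight.)

import math

embed_dim = 192  # assumed value for illustration only
for length_factor in (1.0, 1.5):
    length_scale = length_factor * embed_dim / (2.0 * math.pi)
    # angular gap between offsets 0 and 1 under the atan mapping used in extend_pe
    gap = math.atan(1.0 / length_scale) - math.atan(0.0)
    print(f"length_factor={length_factor}: length_scale={length_scale:.1f}, gap={gap:.4f} rad")
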
@@ -909,8 +913,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
+        assert length_factor >= 1.0
+        self.length_factor = length_factor
         self.extend_pe(torch.tensor(0.0).expand(max_len))

     def extend_pe(self, x: Tensor) -> None:
         """Reset the positional encodings."""
         if self.pe is not None:
@@ -940,18 +948,16 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # is important.
         x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length))

-        # length_factor is chosen so that the FFT can exactly separate points
-        # close to the origin (T == 0). So this part of the formulation is not really
-        # heuristic.
-        length_factor = self.embed_dim / (2.0 * math.pi)
-        # multiplying length_factor by this heuristic constant should reduce the resolution near to the
-        # origin, i.e. reduce its ability to separate points near zero.
-        length_factor *= 2.0
+        # if self.length_factor == 1.0, then length_scale is chosen so that the
+        # FFT can exactly separate points close to the origin (T == 0). So this
+        # part of the formulation is not really heuristic.
+        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)

         # note for machine implementations: if atan is not available, we can use:
         # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
         # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
-        x_atan = (x_compressed / length_factor).atan() # results between -pi and pi
+        x_atan = (x_compressed / length_scale).atan() # results between -pi and pi

         cosines = (x_atan * freqs).cos()
         sines = (x_atan * freqs).sin()
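
(Aside, not part of the commit: a quick standalone check of the atan fallback mentioned in the comment above; the rational approximation tracks atan closely and agrees exactly at x = 0 and |x| = 1.)

import math
import torch

x = torch.linspace(-5.0, 5.0, steps=11)
approx = x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi / 2)
print(torch.stack([x, x.atan(), approx]))  # rows: x, atan(x), approximation
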
@@ -961,14 +967,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         pe[:, 1::2] = sines
         pe[:, -1] = 1.0 # for bias.

-        # if we have the length_factor correct, the cosines around 0 offset (T in the array)
-        # should be oscillating in sign like -1, 1, -1; and the sines should all be close to
-        # zero.
-        #r = 2
-        #print("cosines = ", cosines[T-r:T+r,-5:])
-        #print("sines = ", sines[T-r:T+r,-5:])

         self.pe = pe.to(dtype=x.dtype)
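
(Aside, not part of the commit: a minimal standalone sketch of the encoding that extend_pe builds, following the comments in the diff above. The freqs and compression_length values are assumptions for illustration; the real definitions live elsewhere in zipformer.py and are not shown here. The point is the compression -> atan -> cos/sin pipeline, the interleaving with a final bias column, and the check, mirroring the removed debug prints, that the sines are zero at zero offset.)

import math
import torch

# Assumed values for illustration; the real freqs / compression_length come from
# elsewhere in zipformer.py and are not part of this diff.
embed_dim = 8
length_factor = 1.5
compression_length = embed_dim ** 0.5
freqs = 1 + torch.arange(embed_dim // 2)          # hypothetical frequency choice

T = 10
x = torch.arange(-T, T + 1, dtype=torch.float32).unsqueeze(-1)  # offsets -T..T

# log-compress large offsets; roughly linear near 0
x_compressed = compression_length * x.sign() * (
    (x.abs() + compression_length).log() - math.log(compression_length)
)
length_scale = length_factor * embed_dim / (2.0 * math.pi)
x_atan = (x_compressed / length_scale).atan()

cosines = (x_atan * freqs).cos()
sines = (x_atan * freqs).sin()

pe = torch.zeros(x.shape[0], embed_dim)
pe[:, 0::2] = cosines
pe[:, 1::2] = sines
pe[:, -1] = 1.0       # last column acts as a bias term

print(sines[T])       # row for offset 0: all sines are exactly zero there
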