Merge branch 'scaled_adam_exp445' into scaled_adam_exp450

This commit is contained in:
Daniel Povey 2022-11-21 14:29:50 +08:00
commit a6770657c8
2 changed files with 98 additions and 17 deletions

View File

@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--feedforward-dim", "--feedforward-dim",
type=str, type=str,
default="1536,1536,2048,1536,1536,1536", default="1280,1280,1280,1792,1280,1280",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
) )

View File

@ -404,6 +404,8 @@ class ZipformerEncoderLayer(nn.Module):
self.nonlin_attention_module = NonlinAttentionModule(embed_dim) self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
self.small_conv_module = SmallConvolutionModule(embed_dim)
self.conv_module = ConvolutionModule(embed_dim, self.conv_module = ConvolutionModule(embed_dim,
cnn_module_kernel) cnn_module_kernel)
@ -469,6 +471,8 @@ class ZipformerEncoderLayer(nn.Module):
# multi-headed self-attention module # multi-headed self-attention module
use_self_attn = (random.random() >= dynamic_skip_rate) use_self_attn = (random.random() >= dynamic_skip_rate)
src = src + self.feed_forward1(src)
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or use_self_attn:
# attn_weights: (num_heads, batch_size, seq_len, seq_len) # attn_weights: (num_heads, batch_size, seq_len, seq_len)
attn_weights = self.self_attn_weights( attn_weights = self.self_attn_weights(
@ -482,7 +486,9 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.nonlin_attention_module(src, src = src + self.nonlin_attention_module(src,
attn_weights[0:1]) attn_weights[0:1])
src = src + self.feed_forward1(src)
if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
# pooling module # pooling module
if torch.jit.is_scripting() or use_self_attn: if torch.jit.is_scripting() or use_self_attn:
@ -537,7 +543,8 @@ class ZipformerEncoder(nn.Module):
final_layerdrop_rate: float = 0.05, final_layerdrop_rate: float = 0.05,
) -> None: ) -> None:
super().__init__() super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15) self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
length_factor=3.0)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)] [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@ -890,11 +897,14 @@ class CompactRelPositionalEncoding(torch.nn.Module):
embed_dim: Embedding dimension. embed_dim: Embedding dimension.
dropout_rate: Dropout rate. dropout_rate: Dropout rate.
max_len: Maximum input length: just a heuristic for initialization. max_len: Maximum input length: just a heuristic for initialization.
length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
less weight to small differences of offset near the origin.
""" """
def __init__( def __init__(
self, embed_dim: int, self, embed_dim: int,
dropout_rate: float, dropout_rate: float,
max_len: int = 1000 max_len: int = 1000,
length_factor: float = 1.0,
) -> None: ) -> None:
"""Construct a CompactRelPositionalEncoding object.""" """Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__() super(CompactRelPositionalEncoding, self).__init__()
@ -902,8 +912,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
self.dropout = torch.nn.Dropout(dropout_rate) self.dropout = torch.nn.Dropout(dropout_rate)
self.pe = None self.pe = None
assert length_factor >= 1.0
self.length_factor = length_factor
self.extend_pe(torch.tensor(0.0).expand(max_len)) self.extend_pe(torch.tensor(0.0).expand(max_len))
def extend_pe(self, x: Tensor) -> None: def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings.""" """Reset the positional encodings."""
if self.pe is not None: if self.pe is not None:
@ -933,15 +947,16 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# is important. # is important.
x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length))
# length_factor is chosen so that the FFT can exactly separate points # if self.length_factor == 1.0, then length_scale is chosen so that the
# close to the origin (T == 0). So this part of the formulation is not really # FFT can exactly separate points close to the origin (T == 0). So this
# heuristic. # part of the formulation is not really heuristic.
length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this. # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
# note for machine implementations: if atan is not available, we can use: # note for machine implementations: if atan is not available, we can use:
# x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
# check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
x_atan = (x_compressed / length_factor).atan() # results between -pi and pi x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
cosines = (x_atan * freqs).cos() cosines = (x_atan * freqs).cos()
sines = (x_atan * freqs).sin() sines = (x_atan * freqs).sin()
@ -951,14 +966,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
pe[:, 1::2] = sines pe[:, 1::2] = sines
pe[:, -1] = 1.0 # for bias. pe[:, -1] = 1.0 # for bias.
# if we have the length_factor correct, the cosines around 0 offset (T in the array)
# should be oscillating in sign like -1, 1, -1; and the sines should all be close to
# zero.
#r = 2
#print("cosines = ", cosines[T-r:T+r,-5:])
#print("sines = ", sines[T-r:T+r,-5:])
self.pe = pe.to(dtype=x.dtype) self.pe = pe.to(dtype=x.dtype)
@ -1568,6 +1575,80 @@ class ConvolutionModule(nn.Module):
return x.permute(2, 0, 1) return x.permute(2, 0, 1)
class SmallConvolutionModule(nn.Module):
"""Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, hidden_dim: int = 256,
) -> None:
super().__init__()
self.conv1 = nn.Conv1d(
channels,
hidden_dim,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.deriv_balancer = ActivationBalancer(
hidden_dim, channel_dim=1,
min_positive=0.05, max_positive=1.0,
max_abs=20.0,
)
self.activation = DoubleSwish()
self.conv2 = ScaledConv1d(
hidden_dim,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
initial_scale=0.05,
)
def forward(self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional):
(batch, #time), contains bool in masked positions.
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
x = self.conv1(x) # (batch, hidden_dim, time)
x = self.deriv_balancer(x)
x = self.activation(x)
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.conv2(x)
return x.permute(2, 0, 1)
class Conv2dSubsampling(nn.Module): class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length). """Convolutional 2D subsampling (to 1/2 length).