Mirror of https://github.com/k2-fsa/icefall.git

commit a6770657c8
Merge branch 'scaled_adam_exp445' into scaled_adam_exp450
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1536,1536,2048,1536,1536,1536",
+        default="1280,1280,1280,1792,1280,1280",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
 
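The new default narrows most stacks to 1280 and widens only the fourth stack to 1792. As a minimal sketch of how such a comma-separated string is typically split into per-stack integers before the encoder stacks are built (the helper name is hypothetical, not part of icefall):

# Hypothetical helper (not in icefall) illustrating how a comma-separated
# per-stack dimension string can be parsed before constructing the encoder.
def parse_feedforward_dim(s: str) -> tuple:
    return tuple(int(d) for d in s.split(","))

# The new default yields one feedforward width per zipformer stack:
assert parse_feedforward_dim("1280,1280,1280,1792,1280,1280") == (
    1280, 1280, 1280, 1792, 1280, 1280)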
@@ -404,6 +404,8 @@ class ZipformerEncoderLayer(nn.Module):
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
 
 
+        self.small_conv_module = SmallConvolutionModule(embed_dim)
+
         self.conv_module = ConvolutionModule(embed_dim,
                                              cnn_module_kernel)
 
@@ -469,6 +471,8 @@ class ZipformerEncoderLayer(nn.Module):
         # multi-headed self-attention module
         use_self_attn = (random.random() >= dynamic_skip_rate)
 
+        src = src + self.feed_forward1(src)
+
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
@@ -482,7 +486,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.nonlin_attention_module(src,
                                                  attn_weights[0:1])
 
-        src = src + self.feed_forward1(src)
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
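The two hunks above move feed_forward1 ahead of the self-attention block and gate the new small convolution on a fresh random draw, so it is skipped independently of self-attention with probability dynamic_skip_rate. A distilled, standalone sketch of that gating pattern (not the actual layer code; the rate is presumably set to 0 at inference):

import random
import torch
import torch.nn as nn

def maybe_apply(module: nn.Module, src: torch.Tensor,
                dynamic_skip_rate: float) -> torch.Tensor:
    # Under torch.jit scripting the branch is always taken; otherwise the
    # residual branch is dropped with probability dynamic_skip_rate,
    # acting as a train-time stochastic-depth regularizer.
    if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
        src = src + module(src)
    return src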
@@ -537,7 +543,8 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
+                                                        length_factor=3.0)
 
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -890,11 +897,14 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+           less weight to small differences of offset near the origin.
     """
     def __init__(
         self, embed_dim: int,
         dropout_rate: float,
-        max_len: int = 1000
+        max_len: int = 1000,
+        length_factor: float = 1.0,
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
@@ -902,8 +912,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
+        assert length_factor >= 1.0
+        self.length_factor = length_factor
         self.extend_pe(torch.tensor(0.0).expand(max_len))
 
+
+
     def extend_pe(self, x: Tensor) -> None:
         """Reset the positional encodings."""
         if self.pe is not None:
@@ -933,15 +947,16 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # is important.
         x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length))
 
-        # length_factor is chosen so that the FFT can exactly separate points
-        # close to the origin (T == 0). So this part of the formulation is not really
-        # heuristic.
-        length_factor = self.embed_dim / (2.0 * math.pi)  # todo: test this.
+        # if self.length_factor == 1.0, then length_scale is chosen so that the
+        # FFT can exactly separate points close to the origin (T == 0). So this
+        # part of the formulation is not really heuristic.
+        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
 
         # note for machine implementations: if atan is not available, we can use:
         #   x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
         # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
-        x_atan = (x_compressed / length_factor).atan()  # results between -pi and pi
+        x_atan = (x_compressed / length_scale).atan()  # results between -pi/2 and pi/2
 
         cosines = (x_atan * freqs).cos()
         sines = (x_atan * freqs).sin()
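For intuition: the hunk above rescales the log-compressed offsets before the atan squashing, so with length_factor = 3.0 (the value ZipformerEncoder now passes) a small offset maps to a proportionally smaller angle than with the old implicit factor of 1.0. A standalone numeric sketch of the two formulas, assuming illustrative values embed_dim = 512 and compression_length = 12 (neither value is asserted by this diff; only length_factor = 3.0 is):

import math
import torch

embed_dim, length_factor, compression_length = 512, 3.0, 12.0  # illustrative

x = torch.arange(-1000.0, 1001.0)  # relative position offsets
# monotonic log-compression: roughly linear near the origin, log for large |x|
x_compressed = compression_length * x.sign() * (
    (x.abs() + compression_length).log() - math.log(compression_length))

length_scale = length_factor * embed_dim / (2.0 * math.pi)
x_atan = (x_compressed / length_scale).atan()  # bounded in (-pi/2, pi/2)

# A larger length_factor gives a smaller angle for the same small offset,
# i.e. less weight on small positional differences near the origin.
print(x_atan[1000 + 1].item())  # squashed angle for offset +1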
@@ -951,14 +966,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         pe[:, 1::2] = sines
         pe[:, -1] = 1.0  # for bias.
 
-        # if we have the length_factor correct, the cosines around 0 offset (T in the array)
-        # should be oscillating in sign like -1, 1, -1; and the sines should all be close to
-        # zero.
-        #r = 2
-        #print("cosines = ", cosines[T-r:T+r,-5:])
-        #print("sines = ", sines[T-r:T+r,-5:])
-
-
         self.pe = pe.to(dtype=x.dtype)
 
 
@@ -1568,6 +1575,80 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
+class SmallConvolutionModule(nn.Module):
+    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        hidden_dim (int): The hidden channel dimension between the two
+            conv layers (default=256).
+
+    """
+
+    def __init__(
+        self, channels: int, hidden_dim: int = 256,
+    ) -> None:
+        super().__init__()
+
+        self.conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+        self.deriv_balancer = ActivationBalancer(
+            hidden_dim, channel_dim=1,
+            min_positive=0.05, max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(self,
+                x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), contains bool in masked positions.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        x = self.conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.deriv_balancer(x)
+        x = self.activation(x)
+
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.conv2(x)
+
+        return x.permute(2, 0, 1)
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
 
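A usage sketch for the new module, assuming zipformer.py and its scaling.py helpers (ActivationBalancer, DoubleSwish, ScaledConv1d) are importable from the recipe directory; shape conventions follow the forward docstring:

import torch
from zipformer import SmallConvolutionModule  # import path is an assumption

seq_len, batch, channels = 100, 8, 384  # illustrative sizes
module = SmallConvolutionModule(channels, hidden_dim=256)

x = torch.randn(seq_len, batch, channels)             # (#time, batch, channels)
mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # True = padded frame
mask[:, 90:] = True  # mark the last 10 frames of each utterance as padding

y = module(x, src_key_padding_mask=mask)  # padded frames are zeroed internally
assert y.shape == (seq_len, batch, channels)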