Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Add SmallConvModule; decrease feedforward dims to keep about same num params.
parent f7c99ed1d1
commit 8a095c1cd1
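The "keep about same num params" claim can be sanity-checked with a rough per-layer parameter count. The sketch below assumes an encoder dim of 384 and two feedforward modules per ZipformerEncoderLayer; neither value is visible in this diff, so treat the numbers as illustrative only.

# Rough per-layer bookkeeping (all concrete values are assumptions, not taken from this diff):
embed_dim = 384    # assumed encoder dimension of one stack
hidden_dim = 256   # default hidden_dim of SmallConvolutionModule below

# Added by SmallConvolutionModule:
#   conv1: Conv1d(embed_dim, hidden_dim, kernel_size=3) plus bias
#   conv2: ScaledConv1d(hidden_dim, embed_dim, kernel_size=1) plus bias
added = (3 * embed_dim * hidden_dim + hidden_dim) + (hidden_dim * embed_dim + embed_dim)

# Removed by shrinking feedforward_dim 1536 -> 1280: each feedforward module
# has an in-projection and an out-projection against embed_dim.
delta = 1536 - 1280
removed = 2 * (2 * embed_dim * delta)   # assumes two feedforward modules per layer

print(added, removed)   # roughly 394k added vs. 393k removed per layer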
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1536,1536,2048,1536,1536,1536",
+        default="1280,1280,1536,1280,1280,1280",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
@@ -405,6 +405,8 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
 
+        self.small_conv_module = SmallConvolutionModule(embed_dim)
+
         self.conv_module = ConvolutionModule(embed_dim,
                                              cnn_module_kernel)
 
@@ -483,6 +485,10 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.nonlin_attention_module(src,
                                                      attn_weights[0:1])
 
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+
+
         src = src + self.feed_forward1(src)
 
         # pooling module
@@ -1569,6 +1575,80 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
+class SmallConvolutionModule(nn.Module):
+    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
+
+    Args:
+        channels (int): The number of channels of the module's input and output.
+        hidden_dim (int): The number of channels of the internal kernel-size-3 convolution.
+    """
+
+    def __init__(
+        self, channels: int, hidden_dim: int = 256,
+    ) -> None:
+        super().__init__()
+
+        self.conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+        self.deriv_balancer = ActivationBalancer(
+            hidden_dim, channel_dim=1,
+            min_positive=0.05, max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(self,
+                x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), contains bool in masked positions.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time)
+
+        x = self.conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.deriv_balancer(x)
+        x = self.activation(x)
+
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.conv2(x)
+
+        return x.permute(2, 0, 1)
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
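For reference, a minimal sketch of driving the new SmallConvolutionModule on its own. It assumes the class lives in this recipe's zipformer.py alongside its dependencies (ActivationBalancer, DoubleSwish, ScaledConv1d from scaling.py); the import path below is an assumption, and the tensor layout simply follows the module's docstring.

import torch

from zipformer import SmallConvolutionModule  # assumed import path, see note above

T, N, C = 100, 8, 384                          # frames, batch, channels
module = SmallConvolutionModule(channels=C)    # hidden_dim defaults to 256

x = torch.randn(T, N, C)                       # (#time, batch, channels)
pad_mask = torch.zeros(N, T, dtype=torch.bool)
pad_mask[:, -20:] = True                       # True marks padded frames

y = module(x, src_key_padding_mask=pad_mask)
assert y.shape == (T, N, C)                    # shape-preserving, suitable as a residual branch

# Inside ZipformerEncoderLayer.forward the commit applies it as a residual
# branch that is randomly skipped during training:
#   if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
#       src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)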