Refactor how the downsampling is done so that it happens later, but the 1st encoder stack still operates after a subsampling of 2.
commit d7d5188bd9
parent 0a89f51dc9
@@ -138,7 +138,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--zipformer-downsampling-factors",
         type=str,
-        default="1,2,4",
+        default="2,4,8",
         help="Downsampling factor for each stack of encoder layers.",
     )
 
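For reference, the new defaults keep each stack's internal frame rate unchanged: the conv frontend now subsamples by 2 instead of 4, and each stack's factor doubles to compensate. A minimal sketch (not part of the commit) of how the factors compose, assuming each stack's factor applies to the frontend's output, as DownsampledZipformerEncoder does below:

    # Each encoder stack internally runs at frontend_factor * stack_factor.
    def overall_factors(frontend_factor, stack_factors):
        return [frontend_factor * f for f in stack_factors]

    print(overall_factors(4, (1, 2, 4)))  # old: conv x4 -> [4, 8, 16]
    print(overall_factors(2, (2, 4, 8)))  # new: conv x2 -> [4, 8, 16]

The stream between stacks, however, now stays at the frontend's subsampling of 2 until the final downsample_output.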
@@ -428,7 +428,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
             "feature_dim": 80,
-            "subsampling_factor": 4,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
@@ -443,7 +443,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         return tuple(map(int, s.split(',')))
     encoder = Zipformer(
         num_features=params.feature_dim,
-        subsampling_factor=params.subsampling_factor,
+        output_downsampling_factor=2,
         zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
         encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
 
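The nested to_int_tuple helper visible above parses the comma-separated CLI strings into tuples of ints; for example:

    def to_int_tuple(s: str):
        return tuple(map(int, s.split(',')))

    assert to_int_tuple("2,4,8") == (2, 4, 8)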
@@ -47,7 +47,6 @@ class Zipformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
         d_model: (int,int): embedding dimension of 2 encoder stacks
         attention_dim: (int,int): attention dimension of 2 encoder stacks
         nhead (int, int): number of heads
@@ -62,12 +61,11 @@ class Zipformer(EncoderInterface):
     def __init__(
             self,
             num_features: int,
-            subsampling_factor: int = 4,
-            zipformer_subsampling_factor: int = 4,
+            output_downsampling_factor: int = 2,
             encoder_dims: Tuple[int] = (384, 384),
             attention_dim: Tuple[int] = (256, 256),
             encoder_unmasked_dims: Tuple[int] = (256, 256),
-            zipformer_downsampling_factors: Tuple[int] = (1, 2),
+            zipformer_downsampling_factors: Tuple[int] = (2, 4),
             nhead: Tuple[int] = (8, 8),
             feedforward_dim: Tuple[int] = (1536, 2048),
             num_encoder_layers: Tuple[int] = (12, 12),
@@ -78,23 +76,20 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
 
         self.num_features = num_features
-        self.subsampling_factor = subsampling_factor
-        self.encoder_unmasked_dims = encoder_unmasked_dims
         assert 0 < encoder_dims[0] <= encoder_dims[1]
         self.encoder_dims = encoder_dims
         self.encoder_unmasked_dims = encoder_unmasked_dims
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
+        self.output_downsampling_factor = output_downsampling_factor
 
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
 
-        if subsampling_factor != 4:
-            raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, encoder_dims).
         # That is, it does two things simultaneously:
-        #   (1) subsampling: T -> T//subsampling_factor
+        #   (1) subsampling: T -> T//2
         #   (2) embedding: num_features -> encoder_dims
         self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
                                                dropout=dropout)
@@ -125,10 +120,9 @@ class Zipformer(EncoderInterface):
             )
 
             if zipformer_downsampling_factors[i] != 1:
-                assert i > 0, "First zipformer layer cannot use downsampling"
                 encoder = DownsampledZipformerEncoder(
                     encoder,
-                    input_dim=encoder_dims[i-1],
+                    input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0],
                     output_dim=encoder_dims[i],
                     downsample=zipformer_downsampling_factors[i],
                 )
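With the assert gone, the first stack (i == 0) may itself be wrapped in a DownsampledZipformerEncoder; the new input_dim expression falls back to the stack's own dimension, since there is no previous stack. A small sketch of just that selection logic, using the default dims from the signature above:

    encoder_dims = (384, 384)
    for i in range(len(encoder_dims)):
        input_dim = encoder_dims[i - 1] if i > 0 else encoder_dims[0]
        # i == 0: its own dim (384); i > 0: the previous stack's dim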
@@ -136,6 +130,10 @@ class Zipformer(EncoderInterface):
         self.encoders = nn.ModuleList(encoders)
 
 
+        self.downsample_output = AttentionDownsample(encoder_dims[-1],
+                                                     encoder_dims[-1],
+                                                     downsample=output_downsampling_factor)
+
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -216,8 +214,7 @@ class Zipformer(EncoderInterface):
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            # Caution: We assume the subsampling factor is 4!
-            lengths = ((x_lens - 1) // 2 - 1) // 2
+            lengths = (x_lens - 7) // 2
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
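The new lengths formula matches the conv geometry introduced further down (kernel 3 throughout, no time padding, time strides 1, 2, 1). A quick check, not part of the commit:

    def conv_out_len(t, kernel=3, stride=1, padding=0):
        return (t + 2 * padding - kernel) // stride + 1

    for t in range(7, 200):
        t1 = conv_out_len(t)             # layer 1: time stride 1, padding 0
        t2 = conv_out_len(t1, stride=2)  # layer 2: time stride 2
        t3 = conv_out_len(t2)            # layer 3: time stride 1
        assert t3 == (t - 7) // 2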
@@ -229,6 +226,10 @@ class Zipformer(EncoderInterface):
                         feature_mask=feature_masks[i],
                         src_key_padding_mask=None if mask is None else mask[...,::ds])
 
+        x = self.downsample_output(x)
+        # class Downsample has this rounding behavior..
+        lengths = (x_lens + 1) // 2
+
 
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
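The "+ 1" makes this a ceiling division, which is the rounding the comment attributes to the downsampling class; that it corresponds to right-padding the sequence to a multiple of 2 before striding is an assumption here:

    import math
    for n in range(1, 100):
        assert (n + 1) // 2 == math.ceil(n / 2)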
@@ -1468,7 +1469,7 @@ class Conv2dSubsampling(nn.Module):
 
     Convert an input of shape (N, T, idim) to an output
     with shape (N, T', odim), where
-    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+    T' = (T-3)//2 - 2 == (T-7)//2
 
     It is based on
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
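The two expressions in the new docstring line are the same quantity, since subtracting 2 after flooring by 2 equals subtracting 4 before:

    for t in range(7, 1000):
        assert (t - 3) // 2 - 2 == (t - 7) // 2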
@@ -1489,7 +1490,7 @@ class Conv2dSubsampling(nn.Module):
           Number of channels in. The input shape is (N, T, in_channels).
           Caution: It requires: T >=7, in_channels >=7
         out_channels
-          Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          Output dim. The output shape is (N, (T-3)//2, out_channels)
         layer1_channels:
           Number of channels in layer1
         layer2_channels:
@@ -1503,7 +1504,7 @@ class Conv2dSubsampling(nn.Module):
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
-                padding=1,
+                padding=(0, 1),  # (time, freq)
             ),
             ActivationBalancer(layer1_channels,
                                channel_dim=1),
@@ -1513,6 +1514,7 @@ class Conv2dSubsampling(nn.Module):
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
+                padding=0,
             ),
             ActivationBalancer(layer2_channels,
                                channel_dim=1),
@@ -1521,13 +1523,13 @@ class Conv2dSubsampling(nn.Module):
                 in_channels=layer2_channels,
                 out_channels=layer3_channels,
                 kernel_size=3,
-                stride=2,
+                stride=(1, 2),  # (time, freq)
             ),
             ActivationBalancer(layer3_channels,
                                channel_dim=1),
             DoubleSwish(),
         )
-        out_height = (((in_channels - 1) // 2 - 1) // 2)
+        out_height = (((in_channels - 1) // 2) - 1) // 2
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
         self.dropout = nn.Dropout(dropout)
 
@@ -1545,7 +1547,7 @@ class Conv2dSubsampling(nn.Module):
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
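A shape check for the new geometry (a sketch using plain nn.Conv2d with illustrative channel counts of 8/32/128, not the commit's ScaledConv2d code): the time axis shrinks as (T - 7) // 2 while the freq axis follows the out_height formula above.

    import torch
    import torch.nn as nn

    conv = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=(0, 1)),        # layer 1
        nn.Conv2d(8, 32, kernel_size=3, stride=2, padding=0),  # layer 2
        nn.Conv2d(32, 128, kernel_size=3, stride=(1, 2)),      # layer 3
    )

    T, idim = 100, 80
    y = conv(torch.zeros(1, 1, T, idim))
    assert y.shape[2] == (T - 7) // 2
    assert y.shape[3] == (((idim - 1) // 2) - 1) // 2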