separate Conv2dSubsampling from Zipformer

This commit is contained in:
yaozengwei 2023-04-27 10:11:47 +08:00
parent 0ec31c84da
commit 55a1abc9da
5 changed files with 364 additions and 321 deletions

View File

@ -127,6 +127,7 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
@ -365,9 +366,15 @@ def decode_one_batch(
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
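
The refactored decode path is: run the Conv2dSubsampling front-end, build the padding mask from the subsampled lengths, permute to time-major layout for the Zipformer, then permute the output back. A minimal sketch of that flow follows; the batch size, frame count, and feature dimension are hypothetical, and `model` is the refactored Transducer with separate `encoder_embed` and `encoder` modules as in this file:

    # hedged sketch of the new calling convention; shapes are illustrative only
    feature = torch.randn(2, 100, 80)                        # (N, T, num_features)
    feature_lens = torch.tensor([100, 90])                   # valid frames per utterance
    x, x_lens = model.encoder_embed(feature, feature_lens)   # (N, (T - 7) // 2, C), (N,)
    src_key_padding_mask = make_pad_mask(x_lens)             # True at padded positions
    x = x.permute(1, 0, 2)                                   # (N, T', C) -> (T', N, C)
    encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
    encoder_out = encoder_out.permute(1, 0, 2)               # (T', N, C) -> (N, T', C)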

View File

@ -19,14 +19,11 @@ import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
from scaling import (
penalize_abs_values_gt,
ScaledLinear
)
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
class Transducer(nn.Module):
@ -36,6 +33,7 @@ class Transducer(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
@ -46,6 +44,10 @@ class Transducer(nn.Module):
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
@ -64,18 +66,22 @@ class Transducer(nn.Module):
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_scale=0.25,
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25,
decoder_dim,
vocab_size,
initial_scale=0.25,
)
def forward(
self,
x: torch.Tensor,
@ -119,7 +125,15 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
@ -142,7 +156,9 @@ class Transducer(nn.Module):
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
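
The boundary tensor follows the convention of k2's pruned RNN-T loss: one row per utterance of the form [begin_symbol, begin_frame, end_symbol, end_frame], which here is [0, 0, y_len, x_len]. A small illustration with made-up lengths:

    # hypothetical batch of two utterances
    y_lens = torch.tensor([5, 3])       # number of target tokens per utterance
    x_lens = torch.tensor([46, 40])     # encoder output frames per utterance
    boundary = torch.zeros((2, 4), dtype=torch.int64)
    boundary[:, 2] = y_lens
    boundary[:, 3] = x_lens
    # boundary is now [[0, 0, 5, 46], [0, 0, 3, 40]]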
@ -150,9 +166,9 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
#if self.training and random.random() < 0.25:
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
#if self.training and random.random() < 0.25:
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):

View File

@ -0,0 +1,280 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import warnings
import torch
from torch import Tensor, nn
from scaling import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.hidden_balancer = Balancer(
hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer(
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
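
As the docstring notes, the block is shape-preserving: a depthwise convolution with "same" padding, a pointwise expansion, SwooshL, a pointwise projection back to `channels`, and a residual add. A hedged usage sketch (channel/frame/freq sizes are made up; it assumes scaling.py from this recipe is importable):

    # illustrative only: the module maps (N, C, H, W) -> (N, C, H, W)
    convnext = ConvNeXt(channels=128, kernel_size=(7, 7))
    x = torch.randn(2, 128, 50, 19)   # (batch, channels, num_frames, num_freqs)
    y = convnext(x)
    assert y.shape == x.shape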
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
          in_channels:
            Number of input channels. The input shape is (N, T, in_channels).
            Caution: it requires T >= 7, in_channels >= 7.
          out_channels:
            Output dim. The output shape is (N, (T-7)//2, out_channels).
          layer1_channels:
            Number of channels in layer1.
          layer2_channels:
            Number of channels in layer2.
          layer3_channels:
            Number of channels in layer3.
          dropout:
            Dropout rate (a float or a FloatLike such as ScheduledFloat).
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
out_width = (((in_channels - 1) // 2) - 1) // 2
self.out = nn.Linear(out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
prob=(0.025, 0.25),
grad_scale=0.02,
)
        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large; otherwise there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
Returns:
          - a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
        # Now x is of shape (N, layer3_channels, (T-7)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, (T-7)//2, out_width * layer3_channels)
x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item()
return x, x_lens
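
As a quick sanity check on the arithmetic, T = 100 input frames gives T' = (100-3)//2 - 2 = (100-7)//2 = 46 output frames. A hedged usage sketch (the 80-dim input and 192-dim output are made-up values; it assumes scaling.py from this recipe is importable):

    # illustrative only
    embed = Conv2dSubsampling(in_channels=80, out_channels=192)
    x = torch.randn(2, 100, 80)                        # (N, T, idim)
    x_lens = torch.full((2,), 100, dtype=torch.int64)  # valid frames per utterance
    y, y_lens = embed(x, x_lens)
    print(y.shape)   # torch.Size([2, 46, 192])
    print(y_lens)    # tensor([46, 46])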

View File

@ -64,6 +64,7 @@ from zipformer import Zipformer2
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
from subsampling import Conv2dSubsampling
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
@ -525,29 +526,46 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Zipformer and Transformer
def to_int_tuple(s: str):
def _to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
def get_encoder_embed(params: AttributeDict) -> nn.Module:
# encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, (T - 7) // 2, encoder_dims).
# That is, it does two things simultaneously:
# (1) subsampling: T -> (T - 7) // 2
# (2) embedding: num_features -> encoder_dims
# In the normal configuration, we will downsample once more at the end
# by a factor of 2, and most of the encoder stacks will run at a lower
# sampling rate.
encoder_embed = Conv2dSubsampling(
in_channels=params.feature_dim,
out_channels=_to_int_tuple(params.encoder_dim)[0],
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
)
return encoder_embed
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Zipformer2(
num_features=params.feature_dim,
output_downsampling_factor=2,
downsampling_factor=to_int_tuple(params.downsampling_factor),
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
encoder_dim=to_int_tuple(params.encoder_dim),
encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
query_head_dim=to_int_tuple(params.query_head_dim),
pos_head_dim=to_int_tuple(params.pos_head_dim),
value_head_dim=to_int_tuple(params.value_head_dim),
downsampling_factor=_to_int_tuple(params.downsampling_factor),
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(params.query_head_dim),
pos_head_dim=_to_int_tuple(params.pos_head_dim),
value_head_dim=_to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=to_int_tuple(params.num_heads),
feedforward_dim=to_int_tuple(params.feedforward_dim),
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=params.causal,
chunk_size=to_int_tuple(params.chunk_size),
left_context_frames=to_int_tuple(params.left_context_frames),
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
)
return encoder
@ -564,7 +582,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=int(max(params.encoder_dim.split(','))),
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@ -573,11 +591,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,

View File

@ -18,7 +18,6 @@
import copy
import math
import warnings
import itertools
from typing import List, Optional, Tuple, Union
import logging
import torch
@ -28,13 +27,8 @@ from scaling import (
Balancer,
BiasNorm,
Dropout2,
Dropout3,
SwooshL,
SwooshR,
ChunkCausalDepthwiseConv1d,
ActivationDropoutAndLinear,
ScaledConv1d,
ScaledConv2d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
@ -44,13 +38,9 @@ from scaling import (
FloatLike,
limit_param_value,
convert_num_channels,
ScaleGrad,
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.dist import get_rank
class Zipformer2(EncoderInterface):
"""
@ -60,8 +50,6 @@ class Zipformer2(EncoderInterface):
as downsampling_factor if they are single ints or one-element tuples. The length of
downsampling_factor defines the number of stacks.
num_features (int): Number of input features, e.g. 40.
output_downsampling_factor (int): how much to downsample at the output. Note:
we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
You should probably leave this at 2.
@ -104,7 +92,6 @@ class Zipformer2(EncoderInterface):
"""
def __init__(
self,
num_features: int,
output_downsampling_factor: int = 2,
downsampling_factor: Tuple[int] = (2, 4),
encoder_dim: Union[int, Tuple[int]] = 384,
@ -140,7 +127,6 @@ class Zipformer2(EncoderInterface):
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
return x
self.num_features = num_features # int
self.output_downsampling_factor = output_downsampling_factor # int
self.downsampling_factor = downsampling_factor # tuple
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
@ -160,18 +146,6 @@ class Zipformer2(EncoderInterface):
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, (T - 7) // 2, encoder_dims).
# That is, it does two things simultaneously:
# (1) subsampling: T -> (T - 7) // 2
# (2) embedding: num_features -> encoder_dims
# In the normal configuration, we will downsample once more at the end
# by a factor of 2, and most of the encoder stacks will run at a lower
# sampling rate.
self.encoder_embed = Conv2dSubsampling(num_features, encoder_dim[0],
dropout=dropout)
# each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
encoders = []
@ -297,7 +271,9 @@ class Zipformer2(EncoderInterface):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor,
self, x: torch.Tensor,
x_lens: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -306,34 +282,15 @@ class Zipformer2(EncoderInterface):
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
chunk_size: Number of frames per chunk (only set this if causal == True).
Must divide all elements of downsampling_factor. At 50hz frame
rate, i.e. after encoder_embed. If not specified, no chunking.
left_context_chunks: Number of left-context chunks for each chunk (affects
attention mask); only set this if chunk_size specified. If -1, there
is no limit on the left context. If not -1, require:
left_context_chunks * context_size >= downsampling_factor[i] *
cnn_module_kernel[i] // 2.
src_key_padding_mask:
The mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Returns:
Return a tuple containing 2 tensors:
- embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x = self.encoder_embed(x)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = (x_lens - 7) // 2
assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths)
outputs = []
feature_masks = self.get_feature_masks(x)
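
The src_key_padding_mask documented above uses the convention of icefall's make_pad_mask, which the callers now apply to the subsampled lengths: True marks a padded (invalid) frame. A tiny illustration with made-up lengths:

    from icefall.utils import make_pad_mask
    x_lens = torch.tensor([3, 5])
    mask = make_pad_mask(x_lens)
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])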
@ -379,9 +336,7 @@ class Zipformer2(EncoderInterface):
assert self.output_downsampling_factor == 2
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = (lengths + 1) // 2
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
lengths = (x_lens + 1) // 2
return x, lengths
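
Combining the two stages: Conv2dSubsampling reduces T input frames to (T - 7) // 2, and this final output downsampling halves that again, so the encoder output has roughly T / 4 frames. A worked example with a hypothetical input length:

    T = 100                        # input frames
    t_embed = (T - 7) // 2         # after Conv2dSubsampling: 46
    t_out = (t_embed + 1) // 2     # after output_downsampling_factor=2: 23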
@ -700,12 +655,9 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
float(self.conv_skip_rate))
src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
float(self.ff3_skip_rate))
src = self.balancer1(src)
src = self.norm(src)
@ -1686,244 +1638,13 @@ class ScalarMultiply(nn.Module):
return x * self.scale
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None):
super().__init__()
padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=padding)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels,
out_channels=hidden_channels,
kernel_size=1)
self.hidden_balancer = Balancer(hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01)
self.out_balancer = Balancer(
channels, channel_dim=1,
min_positive=0.4, max_positive=0.6,
min_abs=1.0, max_abs=6.0,
)
self.out_whiten = Whiten(num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
#return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(self,
x: Tensor,
layer_skip_mask: Optional[Tensor] = None) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels
Output dim. The output shape is (N, (T-3)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer1_channels:
Number of channels in layer2
bottleneck:
bottleneck dimension for 1d squeeze-excite
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels,
channel_dim=1,
max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels,
channel_dim=1,
max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels,
channel_dim=1,
max_abs=4.0),
SwooshR(),
)
cur_width = (in_channels - 1) // 2
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
out_width = (((in_channels - 1) // 2) - 1) // 2
self.out = nn.Linear(out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(4.0),
prob=(0.025, 0.25),
grad_scale=0.02)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
return x
def _test_zipformer_main(causal: bool = False):
feature_dim = 50
batch_size = 5
seq_len = 20
feature_dim = 50
# Just make sure the forward pass runs.
c = Zipformer2(
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
causal=causal,
chunk_size=(4,) if causal else (-1,),
left_context_frames=(64,)
@ -1932,19 +1653,18 @@ def _test_zipformer_main(causal: bool = False):
seq_len = 20
# Just make sure the forward pass runs.
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.randn(seq_len, batch_size, 64),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
f[0].sum().backward()
c.eval()
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.randn(seq_len, batch_size, 64),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
f # to remove flake8 warnings
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)