Implement ConvNorm2d and use it in frontend after convnext
parent 71880409cc
commit 244633660d
@@ -616,6 +616,120 @@ class ConvNorm1d(torch.nn.Module):
         return x * scales
 
 
+class PositiveConv2d(nn.Conv2d):
+    """
+    A modified form of nn.Conv2d where the weight parameters are constrained
+    to be positive and there is no bias.
+    """
+    def __init__(
+            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
+            **kwargs):
+        super().__init__(*args, **kwargs, bias=False)
+        self.min = min
+        self.max = max
+
+        # initialize the weight to all-positive values.
+        with torch.no_grad():
+            self.weight[:] = 1.0 / self.weight[0][0].numel()
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Forward function.  Input and returned tensor have shape:
+            (N, C, H, W)
+        i.e. (batch_size, num_channels, height, width)
+        """
+        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
+        # make absolutely sure there are no negative values.  For parameter-averaging-related
+        # reasons, we prefer to also use limit_param_value (above) to keep the weights
+        # positive.
+        weight = weight.abs()
+
+        if self.padding_mode != 'zeros':
+            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                            weight, self.bias, self.stride,
+                            _pair(0), self.dilation, self.groups)
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+class ConvNorm2d(torch.nn.Module):
+    """
+    This is like BasicNorm except the denominator is summed over a local
+    2-D region (e.g. time and frequency) using convolution with positive
+    weights.
+
+    Args:
+      num_channels: the number of channels, e.g. 512.
+      eps: the initial "epsilon" that we add as ballast in:
+             scale = ((input_vec**2).mean() + epsilon)**-0.5
+           Note: our epsilon is actually large, but we keep the name
+           to indicate the connection with conventional LayerNorm.
+      learn_eps: if true, we learn epsilon; if false, we keep it
+           at the initial value.
+      eps_min: float, minimum allowed value of the log of eps
+               (eps is stored in log space)
+      eps_max: float, maximum allowed value of the log of eps
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        eps: float = 0.25,
+        learn_eps: bool = True,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
+        conv_min: float = 0.05,
+        conv_max: float = 1.0,
+        kernel_size: Tuple[int, int] = (3, 3),
+    ) -> None:
+        super().__init__()
+        self.num_channels = num_channels
+        if learn_eps:
+            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+        else:
+            self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.eps_min = eps_min
+        self.eps_max = eps_max
+        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
+        # PositiveConv2d has bias=False and positive weights.
+        self.conv = PositiveConv2d(1, 1, kernel_size=kernel_size, padding=pad,
+                                   min=conv_min, max=conv_max)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        x shape: (N, C, H, W)
+        """
+        assert x.ndim == 4 and x.shape[1] == self.num_channels
+        eps = self.eps
+        if self.training and random.random() < 0.25:
+            # with probability 0.25, in training mode, clamp eps between the min
+            # and max; this will encourage it to learn parameters within the
+            # allowed range by making parameters that are outside the allowed
+            # range noisy.  The rest of the time we use the un-clamped value,
+            # which keeps gradients to allow the parameter to get back into the
+            # allowed region if it happens to exit it.
+            eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
+
+        # sqnorms: (N, 1, H, W)
+        sqnorms = torch.mean(x ** 2, dim=1, keepdim=True)
+
+        # 'counts' is a mechanism to correct for edge effects.
+        # TODO: key-padding mask
+        counts = torch.ones_like(sqnorms)
+        # if src_key_padding_mask is not None:
+        #     counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
+        # sqnorms = sqnorms * counts
+        sqnorms = self.conv(sqnorms)
+        # the clamping is to avoid division by zero for padding frames.
+        counts = torch.clamp(self.conv(counts), min=0.01)
+        # scales: (N, 1, H, W)
+        scales = (sqnorms / counts + eps.exp()) ** -0.5
+        return x * scales
+
+
 def ScaledLinear(*args,
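A minimal usage sketch of the new module (not part of the diff): it assumes the scaling.py from this commit is importable, and the batch/channel/time/freq sizes below are made up for illustration.

    # Sketch only: exercises ConvNorm2d as added above; shapes are illustrative.
    import torch
    from scaling import ConvNorm2d

    norm = ConvNorm2d(num_channels=32, kernel_size=(15, 7))   # (time, freq) kernel
    x = torch.randn(4, 32, 100, 40)   # (batch, channels, time, freq)
    y = norm(x)                       # per-position scales broadcast over channels
    print(y.shape)                    # torch.Size([4, 32, 100, 40])

    # With the default eps=0.25, a unit-variance input comes out with an RMS of
    # roughly (1 + 0.25) ** -0.5, i.e. about 0.89, over the channel dimension.
    print((y ** 2).mean(dim=1).sqrt().mean())

Unlike BasicNorm, the denominator at each (time, freq) position is a local average of the per-position mean squares, smoothed by PositiveConv2d, with the division by 'counts' correcting for zero padding at the edges.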
@@ -28,6 +28,7 @@ from scaling import (
     ActivationBalancer,
     BasicNorm,
     ConvNorm1d,
+    ConvNorm2d,
     Dropout2,
     MaxEig,
     DoubleSwish,
@@ -1792,8 +1793,8 @@ class Conv2dSubsampling(nn.Module):
 
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
                                        ConvNeXt(layer2_channels),
-                                       BasicNorm(layer2_channels,
-                                                 channel_dim=1))
+                                       ConvNorm2d(layer2_channels,
+                                                  kernel_size=(15, 7)))  # (time, freq)
 
 
         self.conv2 = nn.Sequential(
@@ -1812,8 +1813,8 @@ class Conv2dSubsampling(nn.Module):
         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
-                                       BasicNorm(layer3_channels,
-                                                 channel_dim=1))
+                                       ConvNorm2d(layer3_channels,
+                                                  kernel_size=(15, 5)))  # (time, freq)
 
 
         out_height = (((in_channels - 1) // 2) - 1) // 2
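The 'counts' division in ConvNorm2d.forward is what keeps the normalization unbiased at the borders of the feature map, where zero padding would otherwise shrink the convolved sum. Below is a standalone sketch of just that mechanism, using a fixed uniform 3x3 kernel in place of PositiveConv2d (an illustration, not the repo's code).

    # Standalone illustration of the edge-effect correction in ConvNorm2d.forward.
    # The uniform 3x3 kernel stands in for PositiveConv2d; only the counts logic
    # is the point here.
    import torch
    import torch.nn.functional as F

    kernel = torch.full((1, 1, 3, 3), 1.0 / 9)   # positive weights
    sqnorms = torch.ones(1, 1, 6, 6)             # pretend mean(x**2, dim=1) == 1 everywhere

    summed = F.conv2d(sqnorms, kernel, padding=1)  # zero padding shrinks sums at the border
    counts = F.conv2d(torch.ones_like(sqnorms), kernel, padding=1).clamp(min=0.01)

    print(summed[0, 0, 0, 0].item())             # ~0.44 at a corner (only 4 of 9 taps land inside)
    print((summed / counts)[0, 0, 0, 0].item())  # ~1.0 once divided by counts

In the frontend, the same correction matters near utterance boundaries and the spectrum edges, where the (15, 7) and (15, 5) kernels above would otherwise see mostly zero padding.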