Finishing testing encoder_embedding.

Fangjun Kuang 2022-08-05 20:35:23 +08:00
parent b406c1beff
commit 280e1c0312
3 changed files with 119 additions and 42 deletions

scaling.py

@@ -123,16 +123,16 @@ class BasicNorm(torch.nn.Module):
    doesn't have to do this trick. We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
         scale = ((input_vec**2).mean() + epsilon)**-0.5
         Note: our epsilon is actually large, but we keep the name
         to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
         at the initial value.
    """

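As a reference for the formula above, here is a minimal sketch of what this normalization computes, assuming channel_dim=-1 and eps=0.25; basic_norm_ref is an illustrative helper, not code from this commit:

import torch

def basic_norm_ref(x: torch.Tensor, eps: float = 0.25) -> torch.Tensor:
    # scale = ((input_vec**2).mean() + epsilon)**-0.5, per the docstring
    scales = (x.pow(2).mean(dim=-1, keepdim=True) + eps).pow(-0.5)
    return x * scales

x = torch.rand(2, 100, 512)  # (N, T, num_channels)
print(basic_norm_ref(x).shape)  # torch.Size([2, 100, 512])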
scaling_converter.py

@@ -28,7 +28,37 @@ import re
import torch
import torch.nn as nn
from scaling import (
    BasicNorm,
    ScaledConv1d,
    ScaledConv2d,
    ScaledEmbedding,
    ScaledLinear,
)


class NonScaledNorm(nn.Module):
    """See BasicNorm for doc."""

    def __init__(
        self,
        num_channels: int,
        eps_exp: float,
        channel_dim: int = -1,  # CAUTION: see documentation.
    ):
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps_exp = eps_exp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip the shape assertion while tracing so it is not baked
        # into the exported graph.
        if not torch.jit.is_tracing():
            assert x.shape[self.channel_dim] == self.num_channels
        scales = (
            torch.mean(x.pow(2), dim=self.channel_dim, keepdim=True)
            + self.eps_exp
        ).pow(-0.5)
        return x * scales


def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:

@@ -164,6 +194,16 @@ def scaled_embedding_to_embedding(
    return embedding


def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
    assert isinstance(basic_norm, BasicNorm), type(basic_norm)
    norm = NonScaledNorm(
        num_channels=basic_norm.num_channels,
        # BasicNorm stores eps in log space; bake in its exponential so
        # that the converted module needs no exp() at inference time.
        eps_exp=basic_norm.eps.data.exp().item(),
        channel_dim=basic_norm.channel_dim,
    )
    return norm


def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
    """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
    in the given model to their unscaled version `nn.Linear`, `nn.Conv1d`,

@@ -196,6 +236,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
            d[name] = scaled_conv2d_to_conv2d(m)
        elif isinstance(m, ScaledEmbedding):
            d[name] = scaled_embedding_to_embedding(m)
        elif isinstance(m, BasicNorm):
            d[name] = convert_basic_norm(m)

    for k, v in d.items():
        if "." in k:
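
A quick parity check for the new conversion path, sketched under the assumption that scaling.py and scaling_converter.py are importable as above; this is not code from the commit:

import torch
from scaling import BasicNorm
from scaling_converter import convert_basic_norm

norm = BasicNorm(num_channels=512, learn_eps=False)
converted = convert_basic_norm(norm)

x = torch.rand(2, 100, 512)
# eps_exp was baked in as exp(eps) at conversion time, so the two
# modules should agree up to floating-point error.
assert torch.allclose(norm(x), converted(x))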

(test script; filename not shown in this view)

@@ -1,54 +1,89 @@
#!/usr/bin/env python3

from pathlib import Path

import ncnn
import numpy as np
import torch
import torch.nn as nn

from scaling import (
    ActivationBalancer,
    BasicNorm,
    DoubleSwish,
    ScaledConv2d,
    ScaledLinear,
)
from scaling_converter import convert_scaled_to_non_scaled


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        layer1_channels = 8
        layer2_channels = 32
        layer3_channels = 128
        in_channels = 80
        out_channels = 512
        self.out_channels = out_channels
        self.conv = nn.Sequential(
            ScaledConv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=1,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
        # The two stride-2 convs reduce the 80-dim features to
        # ((80 - 1) // 2 - 1) // 2 = 19, so the flattened conv output
        # has 128 * 19 channels.
        self.out = ScaledLinear(
            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
        )
        print(self.out.weight.shape)
        self.out_norm = BasicNorm(out_channels, learn_eps=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim), i.e., (N, C, H, W)
        x = self.conv(x)
        x = x.permute(0, 2, 1, 3)
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim, ((idim-1)//2 - 1)//2)
        # b, c, t, f = x.shape
        # The view() hard-codes batch size 1 for this test.
        x = self.out(x.contiguous().view(1, -1, 128 * 19))
        x = self.out_norm(x)
        return x


def generate_scaled_conv2d():
    print("generating")
    f = Foo()
    f.eval()
    f = convert_scaled_to_non_scaled(f)
    print(f)
    torch.save(f.state_dict(), "f.pt")
    x = torch.rand(1, 100, 80)  # NTC
    m = torch.jit.trace(f, x)
    m.save("foo/scaled_conv2d.pt")
    print(m.graph)


def compare_scaled_conv2d():
    param = "foo/scaled_conv2d.ncnn.param"
    model = "foo/scaled_conv2d.ncnn.bin"
    net = ncnn.Net()
    net.load_param(param)
    net.load_model(model)
    ex = net.create_extractor()
    x = torch.rand(1, 6, 5)  # CHW
    ex.input("in0", ncnn.Mat(x.numpy()).clone())
    ret, out0 = ex.extract("out0")
    assert ret == 0
    out0 = np.array(out0)
    out0 = torch.from_numpy(out0)
    m = torch.jit.load("foo/scaled_conv2d.pt")
    y = m(x.unsqueeze(0)).squeeze(0)
    assert torch.allclose(out0, y, atol=1e-3), (out0 - y).abs().max()


@torch.no_grad()
def main():
    generate_scaled_conv2d()


if __name__ == "__main__":
    main()
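
Note that after this commit the script only generates foo/scaled_conv2d.pt; compare_scaled_conv2d (which still uses the old 1x6x5 input) is no longer reached from main(). The foo/scaled_conv2d.ncnn.param and foo/scaled_conv2d.ncnn.bin files it reads are presumably produced separately from the traced model with ncnn's pnnx converter (e.g. pnnx foo/scaled_conv2d.pt inputshape=[1,100,80]) before the comparison is run; that step is not shown in this commit.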