Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Finishing testing encoder_embedding.
commit 280e1c0312
parent b406c1beff
@@ -123,16 +123,16 @@ class BasicNorm(torch.nn.Module):
     doesn't have to do this trick. We make the "eps" learnable.
 
     Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
-        interprted as an offset from the input's ndim if negative.
-        shis is NOT the num_channels; it should typically be one of
+        interpreted as an offset from the input's ndim if negative.
+        This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
         scale = ((input_vec**2).mean() + epsilon)**-0.5
         Note: our epsilon is actually large, but we keep the name
         to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
         at the initial value.
     """
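Note: the formula in the docstring above is the whole behaviour that the conversion code below has to reproduce. A minimal illustration in plain PyTorch (made-up shapes and eps value, not the icefall implementation):

    import torch

    x = torch.randn(4, 100, 512)  # (N, T, num_channels), channel_dim = -1
    eps = 0.25                    # deliberately "large", as the docstring notes
    scale = (x.pow(2).mean(dim=-1, keepdim=True) + eps).pow(-0.5)
    y = x * scale                 # scale = ((input_vec**2).mean() + epsilon)**-0.5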
@@ -28,7 +28,37 @@ import re
 
 import torch
 import torch.nn as nn
-from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
+from scaling import (
+    BasicNorm,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledEmbedding,
+    ScaledLinear,
+)
+
+
+class NonScaledNorm(nn.Module):
+    """See BasicNorm for doc"""
+
+    def __init__(
+        self,
+        num_channels: int,
+        eps_exp: float,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+    ):
+        super().__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.eps_exp = eps_exp
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if not torch.jit.is_tracing():
+            assert x.shape[self.channel_dim] == self.num_channels
+        scales = (
+            torch.mean(x.pow(2), dim=self.channel_dim, keepdim=True)
+            + self.eps_exp
+        ).pow(-0.5)
+        return x * scales
 
 
 def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
@@ -164,6 +194,16 @@ def scaled_embedding_to_embedding(
     return embedding
 
 
+def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
+    assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
+    norm = NonScaledNorm(
+        num_channels=basic_norm.num_channels,
+        eps_exp=basic_norm.eps.data.exp().item(),
+        channel_dim=basic_norm.channel_dim,
+    )
+    return norm
+
+
 def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
     in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
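Note: convert_basic_norm above bakes the (possibly learned) eps into a constant, eps_exp = exp(eps), so the converted module should reproduce BasicNorm exactly. A quick numerical check along those lines (a sketch, assuming it runs in the same directory as scaling.py and scaling_converter.py; the BasicNorm(num_channels, learn_eps=False) call pattern is the one used in the test script further down):

    import torch
    from scaling import BasicNorm
    from scaling_converter import convert_basic_norm

    norm = BasicNorm(512, learn_eps=False)
    norm.eval()
    converted = convert_basic_norm(norm)  # NonScaledNorm with eps_exp = norm.eps.exp()

    x = torch.rand(2, 100, 512)  # (N, T, C), channel_dim = -1
    with torch.no_grad():
        assert torch.allclose(norm(x), converted(x), atol=1e-6)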
@@ -196,6 +236,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = scaled_conv2d_to_conv2d(m)
         elif isinstance(m, ScaledEmbedding):
             d[name] = scaled_embedding_to_embedding(m)
+        elif isinstance(m, BasicNorm):
+            d[name] = convert_basic_norm(m)
 
     for k, v in d.items():
         if "." in k:
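Note: with the new elif branch above, convert_scaled_to_non_scaled also rewrites BasicNorm layers, so a converted model contains only standard nn modules plus NonScaledNorm. A minimal usage sketch with a hypothetical toy model (the imports assume the same directory layout as the test script below):

    import torch
    import torch.nn as nn
    from scaling import BasicNorm, ScaledLinear
    from scaling_converter import convert_scaled_to_non_scaled

    model = nn.Sequential(ScaledLinear(80, 512), BasicNorm(512, learn_eps=False))
    model.eval()
    model = convert_scaled_to_non_scaled(model)  # ScaledLinear -> nn.Linear, BasicNorm -> NonScaledNorm
    print(model)

    x = torch.rand(1, 10, 80)
    traced = torch.jit.trace(model, x)  # same convert-then-trace pattern as generate_scaled_conv2d() below
    print(traced.graph)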
@@ -1,54 +1,89 @@
 #!/usr/bin/env python3
 
 from pathlib import Path
 
 import ncnn
 import numpy as np
 import torch
-from scaling import ScaledConv2d
-from scaling_converter import scaled_conv2d_to_conv2d
+import torch.nn as nn
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv2d,
+    ScaledLinear,
+)
+from scaling_converter import convert_scaled_to_non_scaled
+
+
+class Foo(nn.Module):
+    def __init__(self):
+        super().__init__()
+        layer1_channels = 8
+        layer2_channels = 32
+        layer3_channels = 128
+        in_channels = 80
+        out_channels = 512
+        self.out_channels = out_channels
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
+        print(self.out.weight.shape)
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+        x = x.permute(0, 2, 1, 3)
+
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        # b, c, t, f = x.shape
+        x = self.out(x.contiguous().view(1, -1, 128 * 19))
+
+        x = self.out_norm(x)
+        return x
 
 
 def generate_scaled_conv2d():
-    f = ScaledConv2d(in_channels=1, out_channels=2, kernel_size=3, padding=1)
-    f = scaled_conv2d_to_conv2d(f)
+    print("generating")
+    f = Foo()
     f.eval()
+    f = convert_scaled_to_non_scaled(f)
     print(f)
-    x = torch.rand(1, 1, 6, 8)  # NCHW
+    torch.save(f.state_dict(), "f.pt")
+    x = torch.rand(1, 100, 80)  # NTC
     m = torch.jit.trace(f, x)
     m.save("foo/scaled_conv2d.pt")
     print(m.graph)
 
 
 def compare_scaled_conv2d():
     param = "foo/scaled_conv2d.ncnn.param"
     model = "foo/scaled_conv2d.ncnn.bin"
 
-    with ncnn.Net() as net:
-        with net.create_extractor() as ex:
+    net = ncnn.Net()
+    net.load_param(param)
+    net.load_model(model)
+
+    ex = net.create_extractor()
     x = torch.rand(1, 6, 5)  # CHW
     ex.input("in0", ncnn.Mat(x.numpy()).clone())
     ret, out0 = ex.extract("out0")
     assert ret == 0
     out0 = np.array(out0)
     out0 = torch.from_numpy(out0)
 
     m = torch.jit.load("foo/scaled_conv2d.pt")
     y = m(x.unsqueeze(0)).squeeze(0)
 
     assert torch.allclose(out0, y, atol=1e-3), (out0 - y).abs().max()
 
 
 @torch.no_grad()
 def main():
     if not Path("foo/scaled_conv2d.ncnn.param").is_file():
         generate_scaled_conv2d()
     else:
         compare_scaled_conv2d()
         generate_scaled_conv2d()
 
 
 if __name__ == "__main__":
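Note: main() above is a two-pass workflow. The first run takes the generate_scaled_conv2d() branch and writes the traced, converted model to foo/scaled_conv2d.pt; the foo/scaled_conv2d.ncnn.param / foo/scaled_conv2d.ncnn.bin pair is presumably produced from that TorchScript file with an external PyTorch-to-ncnn converter (for example pnnx), and a second run then goes down the compare_scaled_conv2d() branch. The pure-PyTorch half of that round trip can be sanity-checked on its own, mirroring the save/reload/allclose pattern of the script (a sketch; Foo and convert_scaled_to_non_scaled are the definitions from this diff and are assumed to be in scope, and the foo/ directory is assumed to exist, as the script itself assumes):

    import torch

    f = Foo()
    f.eval()
    x = torch.rand(1, 100, 80)  # NTC, the same trace input as generate_scaled_conv2d()
    with torch.no_grad():
        y_eager = f(x)

    f = convert_scaled_to_non_scaled(f)
    m = torch.jit.trace(f, x)
    m.save("foo/scaled_conv2d.pt")

    reloaded = torch.jit.load("foo/scaled_conv2d.pt")
    with torch.no_grad():
        y_traced = reloaded(x)
    assert torch.allclose(y_eager, y_traced, atol=1e-3), (y_eager - y_traced).abs().max()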