mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
test conformer.encoder_embed
This commit is contained in:
parent
0615fb316f
commit
d69d83a83e
@ -1592,6 +1592,10 @@ class Conv2dSubsampling(nn.Module):
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||
)
|
||||
|
||||
# ncnn support only batch size == 1
|
||||
self.for_ncnn = False
|
||||
self.conv_out_dim = self.out.weight.shape[1]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
@ -1606,8 +1610,13 @@ class Conv2dSubsampling(nn.Module):
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if torch.jit.is_tracing() and self.for_ncnn:
|
||||
x = self.out(
|
||||
x.transpose(1, 2).contiguous().view(1, -1, self.conv_out_dim)
|
||||
)
|
||||
else:
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
|
@ -5,80 +5,17 @@ import math
|
||||
import ncnn
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
class Foo(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
layer1_channels = 8
|
||||
layer2_channels = 32
|
||||
layer3_channels = 128
|
||||
in_channels = 80
|
||||
out_channels = 512
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(
|
||||
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||
)
|
||||
print(self.out.weight.shape)
|
||||
self.out_norm = BasicNorm(out_channels, eps=1, learn_eps=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
# b, c, t, f = x.shape
|
||||
x = self.out(x.contiguous().view(1, -1, 128 * 19))
|
||||
|
||||
x = self.out_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
x = torch.rand(1, 200, 80)
|
||||
f = torch.jit.load("foo/scaled_conv2d.pt")
|
||||
f = torch.jit.load("foo/encoder_embed.pt")
|
||||
|
||||
param = "foo/scaled_conv2d.ncnn.param"
|
||||
model = "foo/scaled_conv2d.ncnn.bin"
|
||||
param = "foo/encoder_embed.ncnn.param"
|
||||
model = "foo/encoder_embed.ncnn.bin"
|
||||
|
||||
with ncnn.Net() as net:
|
||||
net.load_param(param)
|
||||
|
@ -2,88 +2,42 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
)
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
class Foo(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
layer1_channels = 8
|
||||
layer2_channels = 32
|
||||
layer3_channels = 128
|
||||
in_channels = 80
|
||||
out_channels = 512
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(
|
||||
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||
)
|
||||
print(self.out.weight.shape)
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
def get_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
params.dynamic_chunk_training = False
|
||||
params.short_chunk_size = 25
|
||||
params.num_left_chunks = 4
|
||||
params.causal_convolution = False
|
||||
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
# b, c, t, f = x.shape
|
||||
x = self.out(x.contiguous().view(1, -1, 128 * 19))
|
||||
|
||||
x = self.out_norm(x)
|
||||
return x
|
||||
model = get_transducer_model(params, enable_giga=False)
|
||||
return model
|
||||
|
||||
|
||||
def generate_scaled_conv2d():
|
||||
print("generating")
|
||||
f = Foo()
|
||||
f.eval()
|
||||
f = convert_scaled_to_non_scaled(f)
|
||||
def test_encoder_embedding():
|
||||
model = get_model()
|
||||
model = convert_scaled_to_non_scaled(model)
|
||||
|
||||
f = model.encoder.encoder_embed
|
||||
f.for_ncnn = True
|
||||
print(f)
|
||||
torch.save(f.state_dict(), "f.pt")
|
||||
x = torch.rand(1, 100, 80) # NTC
|
||||
m = torch.jit.trace(f, x)
|
||||
m.save("foo/scaled_conv2d.pt")
|
||||
m.save("foo/encoder_embed.pt")
|
||||
print(m.graph)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
generate_scaled_conv2d()
|
||||
test_encoder_embedding()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user