From d69d83a83e4b698ae050f04653272c020cb6dc77 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 5 Aug 2022 21:28:05 +0800
Subject: [PATCH] test conformer.encoder_embed

---
 .../pruned_transducer_stateless2/conformer.py | 13 ++-
 .../ASR/pruned_transducer_stateless3/t2.py    | 69 +--------------
 .../pruned_transducer_stateless3/test_ncnn.py | 88 +++++--------------
 3 files changed, 35 insertions(+), 135 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index e95360d1d..43697a315 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -1592,6 +1592,10 @@ class Conv2dSubsampling(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )
 
+        # ncnn supports only batch size == 1
+        self.for_ncnn = False
+        self.conv_out_dim = self.out.weight.shape[1]
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
 
@@ -1606,8 +1610,13 @@
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
-        b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if torch.jit.is_tracing() and self.for_ncnn:
+            x = self.out(
+                x.transpose(1, 2).contiguous().view(1, -1, self.conv_out_dim)
+            )
+        else:
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out_norm(x)
         x = self.out_balancer(x)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/t2.py b/egs/librispeech/ASR/pruned_transducer_stateless3/t2.py
index 934784d69..46e046621 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/t2.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/t2.py
@@ -5,80 +5,17 @@
 import math
 import ncnn
 import numpy as np
 import torch
-import torch.nn as nn
-from scaling import (
-    ActivationBalancer,
-    BasicNorm,
-    DoubleSwish,
-    ScaledConv2d,
-    ScaledLinear,
-)
 
 LOG_EPS = math.log(1e-10)
 
 
-class Foo(nn.Module):
-    def __init__(self):
-        super().__init__()
-        layer1_channels = 8
-        layer2_channels = 32
-        layer3_channels = 128
-        in_channels = 80
-        out_channels = 512
-        self.out_channels = out_channels
-        self.conv = nn.Sequential(
-            ScaledConv2d(
-                in_channels=1,
-                out_channels=layer1_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            ActivationBalancer(channel_dim=1),
-            DoubleSwish(),
-            ScaledConv2d(
-                in_channels=layer1_channels,
-                out_channels=layer2_channels,
-                kernel_size=3,
-                stride=2,
-            ),
-            ActivationBalancer(channel_dim=1),
-            DoubleSwish(),
-            ScaledConv2d(
-                in_channels=layer2_channels,
-                out_channels=layer3_channels,
-                kernel_size=3,
-                stride=2,
-            ),
-            ActivationBalancer(channel_dim=1),
-            DoubleSwish(),
-        )
-        self.out = ScaledLinear(
-            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
-        )
-        print(self.out.weight.shape)
-        self.out_norm = BasicNorm(out_channels, eps=1, learn_eps=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # On entry, x is (N, T, idim)
-        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
-        x = self.conv(x)
-        x = x.permute(0, 2, 1, 3)
-
-        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
-        # b, c, t, f = x.shape
-        x = self.out(x.contiguous().view(1, -1, 128 * 19))
-
-        x = self.out_norm(x)
-        return x
-
-
 @torch.no_grad()
 def main():
     x = torch.rand(1, 200, 80)
 
-    f = torch.jit.load("foo/scaled_conv2d.pt")
+    f = torch.jit.load("foo/encoder_embed.pt")
 
-    param = "foo/scaled_conv2d.ncnn.param"
-    model = "foo/scaled_conv2d.ncnn.bin"
+    param = "foo/encoder_embed.ncnn.param"
+    model = "foo/encoder_embed.ncnn.bin"
 
     with ncnn.Net() as net:
         net.load_param(param)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn.py
index ffdfb6594..d24f50a84 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn.py
@@ -2,88 +2,42 @@
 
 import torch
-import torch.nn as nn
-from scaling import (
-    ActivationBalancer,
-    BasicNorm,
-    DoubleSwish,
-    ScaledConv2d,
-    ScaledLinear,
-)
 from scaling_converter import convert_scaled_to_non_scaled
+from train import get_params, get_transducer_model
 
 
-class Foo(nn.Module):
-    def __init__(self):
-        super().__init__()
-        layer1_channels = 8
-        layer2_channels = 32
-        layer3_channels = 128
-        in_channels = 80
-        out_channels = 512
-        self.out_channels = out_channels
-        self.conv = nn.Sequential(
-            ScaledConv2d(
-                in_channels=1,
-                out_channels=layer1_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            ActivationBalancer(channel_dim=1),
-            DoubleSwish(),
-            ScaledConv2d(
-                in_channels=layer1_channels,
-                out_channels=layer2_channels,
-                kernel_size=3,
-                stride=2,
-            ),
-            ActivationBalancer(channel_dim=1),
-            DoubleSwish(),
-            ScaledConv2d(
-                in_channels=layer2_channels,
-                out_channels=layer3_channels,
-                kernel_size=3,
-                stride=2,
-            ),
-            ActivationBalancer(channel_dim=1),
-            DoubleSwish(),
-        )
-        self.out = ScaledLinear(
-            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
-        )
-        print(self.out.weight.shape)
-        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+def get_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # On entry, x is (N, T, idim)
-        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
-        x = self.conv(x)
-        x = x.permute(0, 2, 1, 3)
+    params.dynamic_chunk_training = False
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = False
 
-        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
-        # b, c, t, f = x.shape
-        x = self.out(x.contiguous().view(1, -1, 128 * 19))
-
-        x = self.out_norm(x)
-        return x
+    model = get_transducer_model(params, enable_giga=False)
+    return model
 
 
-def generate_scaled_conv2d():
-    print("generating")
-    f = Foo()
-    f.eval()
-    f = convert_scaled_to_non_scaled(f)
+def test_encoder_embedding():
+    model = get_model()
+    model = convert_scaled_to_non_scaled(model)
+
+    f = model.encoder.encoder_embed
+    f.for_ncnn = True
     print(f)
-    torch.save(f.state_dict(), "f.pt")
     x = torch.rand(1, 100, 80)  # NTC
     m = torch.jit.trace(f, x)
-    m.save("foo/scaled_conv2d.pt")
+    m.save("foo/encoder_embed.pt")
     print(m.graph)
 
 
 @torch.no_grad()
 def main():
-    generate_scaled_conv2d()
+    test_encoder_embedding()
 
 
 if __name__ == "__main__":
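
A minimal sketch of how the comparison in t2.py's main() could continue once
test_ncnn.py has produced foo/encoder_embed.pt and the traced module has been
converted to foo/encoder_embed.ncnn.param / foo/encoder_embed.ncnn.bin. This is
not part of the patch; the blob names "in0" and "out0" and the atol value are
assumptions (read the actual names from the generated .param file).

#!/usr/bin/env python3
# Sketch: compare the traced encoder_embed with its ncnn counterpart.
# Assumes foo/encoder_embed.pt, foo/encoder_embed.ncnn.param, and
# foo/encoder_embed.ncnn.bin exist, and that the converter named the
# input/output blobs "in0"/"out0" (an assumption; check the .param file).
import ncnn
import numpy as np
import torch


@torch.no_grad()
def main():
    x = torch.rand(1, 200, 80)  # (N, T, C); the ncnn branch assumes N == 1

    f = torch.jit.load("foo/encoder_embed.pt")
    torch_out = f(x)  # (1, ((T-1)//2 - 1)//2, 512)

    with ncnn.Net() as net:
        net.load_param("foo/encoder_embed.ncnn.param")
        net.load_model("foo/encoder_embed.ncnn.bin")
        ex = net.create_extractor()
        # ncnn processes a single utterance, so drop the batch dimension
        ex.input("in0", ncnn.Mat(x.squeeze(0).numpy()).clone())
        ret, out0 = ex.extract("out0")
        assert ret == 0, ret
        ncnn_out = torch.from_numpy(np.array(out0))  # (T', 512)

    assert torch.allclose(torch_out.squeeze(0), ncnn_out, atol=1e-5)


if __name__ == "__main__":
    main()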