icefall/egs/yesno/ASR/tdnn/model.py
Fangjun Kuang 57cb611665
[yesno] Remove padding in TDNN (#21)
* Disable SpecAug for yesno.

Also replace Adam with SGD.

* Remove padding in the model to make the results reproducible.
2021-08-23 15:59:36 +08:00

82 lines
2.1 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
import torch
import torch.nn as nn
class Tdnn(nn.Module):
    """A small time-delay neural network built from unpadded 1-D convolutions.

    The network stacks three ``Conv1d -> ReLU -> BatchNorm1d`` stages with 32
    channels each, followed by a per-frame linear projection to
    ``num_classes`` and a log-softmax.

    Because none of the convolutions uses padding, every stage shortens the
    time axis by ``(kernel_size - 1) * dilation`` frames:

        stage 1: (3 - 1) * 1 = 2
        stage 2: (5 - 1) * 2 = 8
        stage 3: (5 - 1) * 4 = 16

    so the output has 26 fewer frames than the input.
    """

    def __init__(self, num_features: int, num_classes: int):
        """
        Args:
          num_features:
            Model input dimension.
          num_classes:
            Model output dimension
        """
        super().__init__()

        # NOTE: the order of the modules inside the Sequential determines the
        # parameter names in the state dict, so it must not be changed.
        self.tdnn = nn.Sequential(
            nn.Conv1d(
                in_channels=num_features,
                out_channels=32,
                kernel_size=3,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
            nn.Conv1d(
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                dilation=2,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
            nn.Conv1d(
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                dilation=4,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
        )
        self.output_linear = nn.Linear(in_features=32, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            The input tensor with shape [N, T, C], where C equals the
            ``num_features`` passed to the constructor. T must be at least
            27 for the unpadded convolutions to yield any output frame.
        Returns:
          A tensor of log-probabilities with shape
          [N, T - 26, num_classes]. The time axis is shorter than the
          input because the convolutions do not pad; the last axis is
          normalized with log-softmax.
        """
        # Conv1d/BatchNorm1d expect channels before time.
        x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
        x = self.tdnn(x)  # [N, C, T] -> [N, 32, T - 26]
        # The linear layer acts on the last axis, so move channels back.
        x = x.permute(0, 2, 1)  # [N, 32, T - 26] -> [N, T - 26, 32]
        x = self.output_linear(x)  # -> [N, T - 26, num_classes]
        x = nn.functional.log_softmax(x, dim=-1)
        return x
def test_tdnn():
    """Smoke-test ``Tdnn`` on a random batch and check the output shape.

    Builds a model with 23 input features and 4 classes, prints its parameter
    count, runs one forward pass on a random [2, 100, 23] tensor, prints the
    input and output shapes, and asserts the output shape.
    """
    num_features = 23
    num_classes = 4
    model = Tdnn(num_features=num_features, num_classes=num_classes)

    num_param = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {num_param}")

    N = 2
    T = 100
    C = num_features
    x = torch.randn(N, T, C)
    y = model(x)
    print(x.shape)
    print(y.shape)

    # The model's convolutions are unpadded, so the time axis shrinks by
    # (3 - 1) * 1 + (5 - 1) * 2 + (5 - 1) * 4 = 26 frames, and the last axis
    # becomes the number of classes.
    assert y.shape == (N, T - 26, num_classes), y.shape
# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    test_tdnn()