mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
* Disable SpecAug for yesno. Also replace Adam with SGD.
* Remove padding in the model to make the results reproducible.
82 lines
2.1 KiB
Python
Executable File
82 lines
2.1 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Tdnn(nn.Module):
    """A small TDNN: three unpadded 1-D convolutions plus a linear output layer.

    Each convolution is followed by ReLU and a BatchNorm1d without learnable
    scale/shift (``affine=False``). None of the convolutions pads its input,
    so the time axis gets shorter by ``dilation * (kernel_size - 1)`` frames
    per layer: 2 + 8 + 16 = 26 frames in total.
    """

    def __init__(self, num_features: int, num_classes: int):
        """
        Args:
          num_features:
            Model input dimension.
          num_classes:
            Model output dimension
        """
        super().__init__()

        self.tdnn = nn.Sequential(
            # kernel 3, dilation 1 -> drops 2 frames
            nn.Conv1d(
                in_channels=num_features,
                out_channels=32,
                kernel_size=3,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
            # kernel 5, dilation 2 -> drops 8 frames
            nn.Conv1d(
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                dilation=2,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
            # kernel 5, dilation 4 -> drops 16 frames
            nn.Conv1d(
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                dilation=4,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
        )
        self.output_linear = nn.Linear(in_features=32, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            The input tensor with shape [N, T, C], where C is the
            `num_features` given to the constructor. T must be at least 27,
            otherwise the unpadded convolutions have no valid output frame.

        Returns:
          A tensor of log-probabilities with shape [N, T - 26, num_classes].
          The time axis is shorter than the input because the convolutions
          use no padding; the last axis is normalized with log-softmax.
        """
        # Conv1d expects channels before time.
        x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
        x = self.tdnn(x)
        x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
        x = self.output_linear(x)
        x = nn.functional.log_softmax(x, dim=-1)
        return x
|
|
|
|
|
|
def test_tdnn():
    """Smoke-test Tdnn: build it, run a random batch, and check the output shape.

    Prints the parameter count and the input/output shapes, then asserts that
    the output keeps the batch size, has `num_classes` as its last dimension,
    and has lost exactly the frames consumed by the unpadded convolutions.
    """
    num_features = 23
    num_classes = 4
    model = Tdnn(num_features=num_features, num_classes=num_classes)
    num_param = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {num_param}")
    N = 2
    T = 100
    C = num_features
    x = torch.randn(N, T, C)
    y = model(x)
    print(x.shape)
    print(y.shape)

    # With no padding and stride 1, each Conv1d shortens the time axis by
    # dilation * (kernel_size - 1) frames.
    lost_frames = sum(
        layer.dilation[0] * (layer.kernel_size[0] - 1)
        for layer in model.tdnn
        if isinstance(layer, nn.Conv1d)
    )
    assert y.shape == (N, T - lost_frames, num_classes), y.shape
|
|
|
|
|
|
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_tdnn()
|