Fangjun Kuang 95af039733
RNN-T training for yesno. (#141)
* RNN-T training for yesno.

* Rename Jointer to Joiner.
2021-12-07 21:44:37 +08:00

88 lines
2.9 KiB
Python

# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
# We use a TDNN model as encoder, as it works very well with CTC training
# for this tiny dataset.
class Tdnn(nn.Module):
def __init__(self, num_features: int, output_dim: int):
"""
Args:
num_features:
Model input dimension.
ouput_dim:
Model output dimension
"""
super().__init__()
# Note: We don't use paddings inside conv layers
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=32,
kernel_size=3,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=2,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=4,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
)
self.output_linear = nn.Linear(in_features=32, out_features=output_dim)
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The input tensor with shape (N, T, C)
x_lens:
It contains the number of frames in each utterance in x
before padding.
Returns:
Return a tuple with 2 tensors:
- logits, a tensor of shape (N, T, C)
- logit_lens, a tensor of shape (N,)
"""
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.tdnn(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
logits = self.output_linear(x)
# the first conv layer reduces T by 3-1 frames
# the second layer reduces T by (5-1)*2 frames
# the second layer reduces T by (5-1)*4 frames
# Number of output frames is 2 + 4*2 + 4*4 = 2 + 8 + 16 = 26
x_lens = x_lens - 26
return logits, x_lens