Refactor decoder and joiner to remove extra nn.Linear().

Fangjun Kuang 2022-03-09 22:59:01 +08:00
parent 7d1b064c96
commit 9071b1420d
4 changed files with 28 additions and 116 deletions

@@ -1,100 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the
    prediction network. Different from the above paper, it adds an extra
    Conv1d right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          embedding_dim:
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id
        assert context_size >= 1, context_size
        self.context_size = context_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                in_channels=embedding_dim,
                out_channels=embedding_dim,
                kernel_size=context_size,
                padding=0,
                groups=embedding_dim,
                bias=False,
            )
        self.output_linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U) with blank prepended.
          need_pad:
            True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, vocab_size), since output_linear
          projects from embedding_dim to vocab_size.
        """
        embedding_out = self.embedding(y)
        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                embedding_out = F.pad(
                    embedding_out, pad=(self.context_size - 1, 0)
                )
            else:
                # During inference time, there is no need to do extra padding
                # as we only need one output
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
        embedding_out = self.output_linear(F.relu(embedding_out))
        return embedding_out
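
For orientation, here is a minimal sketch of how this (now deleted) decoder is driven; the batch size, vocabulary size, and context size below are illustrative assumptions, not values taken from the commit:

import torch

# Illustrative sizes only (assumptions, not from this commit).
decoder = Decoder(vocab_size=500, embedding_dim=512, blank_id=0, context_size=2)

# Training: full label sequences with blank prepended. forward() left-pads by
# context_size - 1, so the output length U matches the input length.
y = torch.randint(low=1, high=500, size=(8, 20))  # (N, U)
train_out = decoder(y, need_pad=True)             # (8, 20, 500)

# Inference: feed exactly the last context_size tokens and skip the padding;
# the depthwise Conv1d then emits a single frame.
y_last = y[:, -2:]                                # (N, context_size)
infer_out = decoder(y_last, need_pad=False)       # (8, 1, 500)

The grouped Conv1d is what replaces the recurrent connection: a fixed left context of context_size - 1 tokens, with no hidden state to carry across decoding steps.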

@@ -0,0 +1 @@
+../transducer_stateless/decoder.py

@@ -20,11 +20,20 @@ import torch.nn.functional as F
 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, output_dim: int):
+        """
+        Args:
+          input_dim:
+            Input dim of the joiner. It should be equal
+            to the output dim of the encoder and decoder.
+          output_dim:
+            Output dim of the joiner. It should be equal
+            to the vocab_size.
+        """
         super().__init__()
-        self.inner_linear = nn.Linear(input_dim, inner_dim)
-        self.output_linear = nn.Linear(inner_dim, output_dim)
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.output_linear = nn.Linear(input_dim, output_dim)

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -40,11 +49,10 @@ class Joiner(nn.Module):
         """
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape == decoder_out.shape
+        assert encoder_out.size(-1) == self.input_dim

-        logit = encoder_out + decoder_out
-        logit = self.inner_linear(torch.tanh(logit))
-        output = self.output_linear(F.relu(logit))
-        return output
+        x = encoder_out + decoder_out
+        activations = torch.tanh(x)
+        logits = self.output_linear(activations)
+        return logits
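
The hunk above is the heart of the commit: the joiner drops inner_linear and the extra ReLU. A minimal sketch of the two computation paths side by side, with all sizes assumed for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

N, T, U, C, D, V = 2, 50, 10, 512, 512, 500  # illustrative sizes

x = torch.randn(N, T, U, C)  # encoder_out + decoder_out, already rank-4

# Old path: two projections (inner_linear is the nn.Linear removed here).
inner_linear = nn.Linear(C, D)
old_output_linear = nn.Linear(D, V)
old_logits = old_output_linear(F.relu(inner_linear(torch.tanh(x))))

# New path: a single projection after tanh.
output_linear = nn.Linear(C, V)
new_logits = output_linear(torch.tanh(x))

assert old_logits.shape == new_logits.shape == (N, T, U, V)

Dropping inner_linear removes an entire C-by-D weight matrix from the hottest part of the model, since the joiner runs once per (t, u) pair.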

@@ -46,8 +46,8 @@ class Transducer(nn.Module):
           is (N, U) and its output shape is (N, U, C). It should contain
           one attribute: `blank_id`.
         joiner:
-          It has two inputs with shapes: (N, T, C) and (N, U, C). Its
-          output shape is (N, T, U, C). Note that its output contains
+          It has two inputs with shapes: (N, T, U, C) and (N, T, U, C). Its
+          output shape is also (N, T, U, C). Note that its output contains
           unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
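
The corrected docstring reflects that both joiner inputs are already rank-4. A sketch of the unsqueeze/expand preparation the caller is assumed to perform (Joiner is the new class from joiner.py above; all sizes are illustrative):

import torch

N, T, U, C = 2, 50, 10, 512               # illustrative sizes
encoder_out = torch.randn(N, T, C)        # (N, T, C) from the encoder
decoder_out = torch.randn(N, U, C)        # (N, U, C) from the decoder

# Expand both to (N, T, U, C) so the joiner can add them elementwise.
encoder_out = encoder_out.unsqueeze(2).expand(N, T, U, C)
decoder_out = decoder_out.unsqueeze(1).expand(N, T, U, C)

joiner = Joiner(input_dim=C, output_dim=500)
logits = joiner(encoder_out, decoder_out)  # (N, T, U, 500), unnormalized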

@@ -246,6 +246,7 @@ def get_params() -> AttributeDict:
             "log_diagnostics": False,
             # parameters for conformer
             "feature_dim": 80,
+            "encoder_out_dim": 512,
             "subsampling_factor": 4,
             "attention_dim": 512,
             "nhead": 8,
@@ -267,7 +268,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.vocab_size,
+        output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -279,9 +280,12 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 def get_decoder_model(params: AttributeDict) -> nn.Module:
+    # Note: We set the embedding_dim of the decoder to
+    # vocab_size so that its output can be added with
+    # that of the encoder
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -290,8 +294,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
+        input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
     )
     return joiner
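
To illustrate why the new encoder_out_dim parameter is threaded through all three builders, a hedged sketch of the width constraints, assuming the post-refactor Decoder (the symlinked ../transducer_stateless/decoder.py, which is not shown in this diff) returns (N, U, embedding_dim):

vocab_size = 500        # illustrative; in the recipe it comes from the BPE model
encoder_out_dim = 512   # the default this commit adds to get_params()

# The encoder emits (N, T, encoder_out_dim) and the decoder is assumed to emit
# (N, U, embedding_dim), so all three widths must agree for the joiner's
# elementwise addition to type-check.
decoder = Decoder(
    vocab_size=vocab_size,
    embedding_dim=encoder_out_dim,  # must equal the encoder's output dim
    blank_id=0,                     # illustrative
    context_size=2,                 # illustrative
)
joiner = Joiner(
    input_dim=encoder_out_dim,      # width of encoder_out and decoder_out
    output_dim=vocab_size,          # logits over the vocabulary
)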