Mirror of https://github.com/k2-fsa/icefall.git
Refactor decoder and joiner to remove extra nn.Linear().
Parent: 7d1b064c96
Commit: 9071b1420d
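At its core the change drops the joiner's inner projection: the old forward ran tanh, an inner nn.Linear, ReLU, and the output projection, while the new one applies tanh followed by a single output nn.Linear. The decoder side is handled by deleting the per-recipe decoder.py (which carried its own output_linear) and pointing at the shared transducer_stateless implementation instead. Below is a minimal sketch of the before/after joiner computation; all sizes are placeholders except 512, which matches the encoder_out_dim introduced in the train.py hunks further down.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder sizes for illustration; only 512 corresponds to the new
# "encoder_out_dim" parameter added by this commit.
input_dim, inner_dim, vocab_size = 512, 512, 500
x = torch.randn(2, 10, 7, input_dim)  # encoder_out + decoder_out, already added

# Before: tanh -> inner_linear -> ReLU -> output_linear (two nn.Linear layers).
inner_linear = nn.Linear(input_dim, inner_dim)
output_linear_old = nn.Linear(inner_dim, vocab_size)
logits_old = output_linear_old(F.relu(inner_linear(torch.tanh(x))))

# After: tanh -> output_linear (the extra nn.Linear and the ReLU are gone).
output_linear_new = nn.Linear(input_dim, vocab_size)
logits_new = output_linear_new(torch.tanh(x))

assert logits_old.shape == logits_new.shape == (2, 10, 7, vocab_size)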
@@ -1,100 +0,0 @@
-# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Decoder(nn.Module):
-    """This class modifies the stateless decoder from the following paper:
-
-        RNN-transducer with stateless prediction network
-        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-
-    It removes the recurrent connection from the decoder, i.e., the prediction
-    network. Different from the above paper, it adds an extra Conv1d
-    right after the embedding layer.
-
-    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        context_size: int,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          context_size:
-            Number of previous words to use to predict the next word.
-            1 means bigram; 2 means trigram. n means (n+1)-gram.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.blank_id = blank_id
-
-        assert context_size >= 1, context_size
-        self.context_size = context_size
-        if context_size > 1:
-            self.conv = nn.Conv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
-                kernel_size=context_size,
-                padding=0,
-                groups=embedding_dim,
-                bias=False,
-            )
-        self.output_linear = nn.Linear(embedding_dim, vocab_size)
-
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U) with blank prepended.
-          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
-        Returns:
-          Return a tensor of shape (N, U, embedding_dim).
-        """
-        embedding_out = self.embedding(y)
-        if self.context_size > 1:
-            embedding_out = embedding_out.permute(0, 2, 1)
-            if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
-            else:
-                # During inference time, there is no need to do extra padding
-                # as we only need one output
-                assert embedding_out.size(-1) == self.context_size
-            embedding_out = self.conv(embedding_out)
-            embedding_out = embedding_out.permute(0, 2, 1)
-        embedding_out = self.output_linear(F.relu(embedding_out))
-        return embedding_out
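For reference, the deleted decoder's need_pad contract works like this: during training the embedding output is left-padded by context_size - 1 frames so the depthwise Conv1d preserves the label length, while at inference exactly context_size labels are fed with no padding, yielding a single output frame. A hedged usage sketch follows; the import path and every size in it are invented for illustration.

import torch

from decoder import Decoder  # hypothetical import path for the class shown above

# Illustrative sizes only.
decoder = Decoder(vocab_size=500, embedding_dim=512, blank_id=0, context_size=2)

# Training: full label sequence (blank prepended); need_pad=True left-pads by
# context_size - 1 so the Conv1d output keeps length U.
y = torch.randint(low=1, high=500, size=(8, 20))
out = decoder(y, need_pad=True)
assert out.shape == (8, 20, 500)  # this old decoder still projects to vocab_size

# Inference: feed exactly the last context_size labels and skip the padding,
# producing one output frame per utterance.
step = decoder(y[:, -2:], need_pad=False)
assert step.shape == (8, 1, 500)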
@@ -0,0 +1 @@
+../transducer_stateless/decoder.py
@@ -20,11 +20,20 @@ import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, output_dim: int):
+        """
+        Args:
+          input_dim:
+            Input dim of the joiner. It should be equal
+            to the output dim of the encoder and decoder.
+          output_dim:
+            Output dim of the joiner. It should be equal
+            to the vocab_size.
+        """
         super().__init__()
-
-        self.inner_linear = nn.Linear(input_dim, inner_dim)
-        self.output_linear = nn.Linear(inner_dim, output_dim)
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.output_linear = nn.Linear(input_dim, output_dim)
 
     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -40,11 +49,10 @@ class Joiner(nn.Module):
         """
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape == decoder_out.shape
+        assert encoder_out.size(-1) == self.input_dim
 
-        logit = encoder_out + decoder_out
+        x = encoder_out + decoder_out
+        activations = torch.tanh(x)
+        logits = self.output_linear(activations)
 
-        logit = self.inner_linear(torch.tanh(logit))
-
-        output = self.output_linear(F.relu(logit))
-
-        return output
+        return logits
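The asserts above pin down the joiner's contract: both inputs must already be broadcast to the same 4-D shape (N, T, U, C) with C equal to input_dim, and the result is unnormalized logits over the vocabulary. A hedged sketch of calling the refactored Joiner; the import path and all sizes are invented, and the explicit expand here only stands in for whatever broadcasting the recipe does before the call.

import torch

from joiner import Joiner  # hypothetical import path for the class refactored above

encoder_out_dim, vocab_size = 512, 500  # illustrative sizes
joiner = Joiner(input_dim=encoder_out_dim, output_dim=vocab_size)

N, T, U = 4, 25, 9
# Broadcast encoder (N, T, C) and decoder (N, U, C) outputs to a common
# (N, T, U, C) shape before calling the joiner, as its asserts require.
encoder_out = torch.randn(N, T, 1, encoder_out_dim).expand(N, T, U, encoder_out_dim)
decoder_out = torch.randn(N, 1, U, encoder_out_dim).expand(N, T, U, encoder_out_dim)

logits = joiner(encoder_out, decoder_out)
assert logits.shape == (N, T, U, vocab_size)  # unnormalized, no log-softmax applied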
@@ -46,8 +46,8 @@ class Transducer(nn.Module):
             is (N, U) and its output shape is (N, U, C). It should contain
             one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, C) and (N, U, C). Its
-            output shape is (N, T, U, C). Note that its output contains
+            It has two inputs with shapes: (N, T, U, C) and (N, T, U, C). Its
+            output shape is also (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
@@ -246,6 +246,7 @@ def get_params() -> AttributeDict:
             "log_diagnostics": False,
             # parameters for conformer
             "feature_dim": 80,
+            "encoder_out_dim": 512,
             "subsampling_factor": 4,
             "attention_dim": 512,
             "nhead": 8,
@@ -267,7 +268,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.vocab_size,
+        output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -279,9 +280,12 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 
 
 def get_decoder_model(params: AttributeDict) -> nn.Module:
+    # Note: We set the embedding_dim of the decoder to
+    # vocab_size so that its output can be added with
+    # that of the encoder
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -290,8 +294,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
+        input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
     )
     return joiner
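Taken together, the train.py hunks make encoder_out_dim (512) the single interface dimension: the Conformer emits it, the decoder embeds into it, and the joiner consumes it before projecting once to vocab_size. A hedged sketch of the resulting wiring, assuming the constructors shown in this diff; the import paths and every value except encoder_out_dim = 512 are placeholders.

from decoder import Decoder  # hypothetical import paths for the classes
from joiner import Joiner    # touched by this commit

# Only encoder_out_dim = 512 comes from get_params() above; the rest are placeholders.
encoder_out_dim, vocab_size, blank_id, context_size = 512, 500, 0, 2

# encoder: Conformer(..., output_dim=encoder_out_dim)   -> (N, T, encoder_out_dim)
# decoder: embeds labels into the same dimension        -> (N, U, encoder_out_dim)
# joiner : adds the two and projects once to vocab_size -> (N, T, U, vocab_size)
decoder = Decoder(
    vocab_size=vocab_size,
    embedding_dim=encoder_out_dim,  # matches the encoder output dim
    blank_id=blank_id,
    context_size=context_size,
)
joiner = Joiner(input_dim=encoder_out_dim, output_dim=vocab_size)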