2022-01-28 20:05:48 +08:00

123 lines
4.8 KiB
Python

# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
    """Stateless prediction network (decoder) for an RNN-transducer.

    This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the prediction
    network. Different from the above paper, it adds an extra Conv1d
    right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        context_size: int,
        use_right_context: bool = False,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          embedding_dim:
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol. It is also used as the padding index
            of the embedding table.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
            Must be at least 1.
          use_right_context:
            True to use right context, which is useful to implement a
            backward decoder, only used for training.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.use_right_context = use_right_context

        # With context_size == 1 the embedding alone is the decoder output,
        # so the convolution is only created for larger contexts.
        if context_size > 1:
            # One filter per embedding channel (groups == channels); each
            # filter spans `context_size` consecutive tokens. No implicit
            # padding: forward() pads explicitly on the required side.
            self.conv = nn.Conv1d(
                in_channels=embedding_dim,
                out_channels=embedding_dim,
                kernel_size=context_size,
                padding=0,
                groups=embedding_dim,
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U) with blank prepended.
          need_pad:
            True to pad the input (on the left, or on the right if
            `use_right_context` is set). Should be True during training.
            False to not pad the input. Should be False during inference,
            in which case U must equal `context_size` (when
            `context_size > 1`) and right context must not be in use.
        Returns:
          Return a tensor of shape (N, U, embedding_dim). When `need_pad`
          is false and `context_size > 1`, the shape is
          (N, 1, embedding_dim) since the kernel covers the whole input.
        """
        embedding_out = self.embedding(y)
        if self.context_size > 1:
            # Conv1d expects (N, C, U), so move the channel axis forward.
            embedding_out = embedding_out.permute(0, 2, 1)
            # Truthiness (not `is True`) so that truthy non-bool flags such
            # as 1 or numpy.bool_ select the padding branch as intended.
            if need_pad:
                # Regarding the left or right context we are using:
                #
                # If we feed sequence [sos, a, b, c, d] to this decoder and
                # want to predict the sequence [a, b, c, d], then after
                # padding to the left with context_size==2 the input becomes
                # [pad, sos, a, b, c, d]; we use `pad,sos` to predict `a`,
                # `sos,a` to predict `b`, ... That is left context.
                #
                # If we feed sequence [b, c, d, blk, blk] to this decoder
                # and want to predict the sequence [a, b, c, d], then after
                # padding to the right with context_size==2 the input becomes
                # [b, c, d, blk, blk, pad]; we use `b,c` to predict `a`,
                # `c,d` to predict `b`, ... That is right context.
                #
                # This is tricky and not so straightforward, will find better
                # implementation later.
                num_pad = self.context_size - 1
                if self.use_right_context:
                    embedding_out = F.pad(embedding_out, pad=(0, num_pad))
                else:
                    embedding_out = F.pad(embedding_out, pad=(num_pad, 0))
            else:
                # During inference time, there is no need to do extra padding
                # as we only need one output
                assert (
                    embedding_out.size(-1) == self.context_size
                ), "without padding, y must contain exactly context_size tokens"
                assert (
                    not self.use_right_context
                ), "right context is only supported with need_pad=True"
            embedding_out = self.conv(embedding_out)
            # Back to (N, U, C) for the caller.
            embedding_out = embedding_out.permute(0, 2, 1)
        return embedding_out