Increase the size of the context in the RNN-T decoder.

This commit is contained in:
Fangjun Kuang 2021-12-18 23:54:31 +08:00
parent cb04c8a750
commit 04977175a3
4 changed files with 96 additions and 0 deletions

View File

@ -130,6 +130,8 @@ def get_params() -> AttributeDict:
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
# decoder params
"env_info": get_env_info(),
}
@ -158,6 +160,7 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder

View File

@ -16,6 +16,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
@ -35,6 +36,7 @@ class Decoder(nn.Module):
vocab_size: int,
embedding_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
@ -44,6 +46,9 @@ class Decoder(nn.Module):
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
@ -53,6 +58,18 @@ class Decoder(nn.Module):
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=embedding_dim,
out_channels=embedding_dim,
kernel_size=context_size,
padding=0,
groups=embedding_dim,
bias=False,
)
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
@ -62,4 +79,16 @@ class Decoder(nn.Module):
Return a tensor of shape (N, U, embedding_dim).
"""
embeding_out = self.embedding(y)
if self.context_size > 1:
embeding_out = embeding_out.permute(0, 2, 1)
if self.training is True:
embeding_out = F.pad(
embeding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embeding_out.size(-1) == self.context_size
embeding_out = self.conv(embeding_out)
embeding_out = embeding_out.permute(0, 2, 1)
return embeding_out

View File

@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_stateless/test_decoder.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
embedding_dim = 128
context_size = 4
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
context_size=context_size,
)
N = 100
U = 20
x = torch.randint(low=0, high=vocab_size, size=(N, U))
y = decoder(x)
assert y.shape == (N, U, embedding_dim)
# for inference
decoder.eval()
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
y = decoder(x)
assert y.shape == (N, 1, embedding_dim)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@ -202,6 +202,8 @@ def get_params() -> AttributeDict:
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
@ -233,6 +235,7 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder