From 04977175a3efd7ff3fb0722a75df22953fa75333 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 18 Dec 2021 23:54:31 +0800 Subject: [PATCH] Increase the size of the context in the RNN-T decoder. --- .../ASR/transducer_stateless/decode.py | 3 + .../ASR/transducer_stateless/decoder.py | 29 +++++++++ .../ASR/transducer_stateless/test_decoder.py | 61 +++++++++++++++++++ .../ASR/transducer_stateless/train.py | 3 + 4 files changed, 96 insertions(+) create mode 100755 egs/librispeech/ASR/transducer_stateless/test_decoder.py diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 2fa5cc55e..d51af397a 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -130,6 +130,8 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, "use_feat_batchnorm": True, + # parameters for decoder + "context_size": 2, # tri-gram # decoder params "env_info": get_env_info(), } @@ -158,6 +160,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + context_size=params.context_size, ) return decoder diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 9d6b3aaf2..0773ce37b 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F class Decoder(nn.Module): @@ -35,6 +36,7 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, + context_size: int, ): """ Args: @@ -44,6 +46,9 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() self.embedding = nn.Embedding( @@ -53,6 +58,18 @@ class Decoder(nn.Module): ) self.blank_id = blank_id + assert context_size >= 1, context_size + self.context_size = context_size + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=embedding_dim, + out_channels=embedding_dim, + kernel_size=context_size, + padding=0, + groups=embedding_dim, + bias=False, + ) + def forward(self, y: torch.Tensor) -> torch.Tensor: """ Args: @@ -62,4 +79,16 @@ class Decoder(nn.Module): Return a tensor of shape (N, U, embedding_dim). """ embeding_out = self.embedding(y) + if self.context_size > 1: + embeding_out = embeding_out.permute(0, 2, 1) + if self.training is True: + embeding_out = F.pad( + embeding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embeding_out.size(-1) == self.context_size + embeding_out = self.conv(embeding_out) + embeding_out = embeding_out.permute(0, 2, 1) return embeding_out diff --git a/egs/librispeech/ASR/transducer_stateless/test_decoder.py b/egs/librispeech/ASR/transducer_stateless/test_decoder.py new file mode 100755 index 000000000..532aaf776 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/test_decoder.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_stateless/test_decoder.py +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from decoder import Decoder + + +def test_decoder(): + vocab_size = 3 + blank_id = 0 + embedding_dim = 128 + context_size = 4 + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + context_size=context_size, + ) + N = 100 + U = 20 + x = torch.randint(low=0, high=vocab_size, size=(N, U)) + y = decoder(x) + assert y.shape == (N, U, embedding_dim) + + # for inference + decoder.eval() + x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) + y = decoder(x) + assert y.shape == (N, 1, embedding_dim) + + +def main(): + test_decoder() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index e20aedf9b..4e0515cbf 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -202,6 +202,8 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, "use_feat_batchnorm": True, + # parameters for decoder + "context_size": 2, # tri-gram # parameters for Noam "weight_decay": 1e-6, "warm_step": 80000, # For the 100h subset, use 8k @@ -233,6 +235,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + context_size=params.context_size, ) return decoder