From 3c89734b79357508c0274b04895f64cc7d65a446 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 28 Dec 2021 11:09:43 +0800
Subject: [PATCH] Use similar number of parameters as conformer encoder.

---
 egs/librispeech/ASR/README.md                 |  2 +-
 egs/librispeech/ASR/transducer_lstm/decode.py | 28 ++++--
 .../ASR/transducer_lstm/decoder.py            | 99 +------------------
 egs/librispeech/ASR/transducer_lstm/model.py  |  6 +-
 egs/librispeech/ASR/transducer_lstm/train.py  |  4 +-
 5 files changed, 28 insertions(+), 111 deletions(-)
 mode change 100644 => 120000 egs/librispeech/ASR/transducer_lstm/decoder.py

diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index c8ee98d7d..9fd657eec 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -13,7 +13,7 @@ The following table lists the differences among them.
 |------------------------|-----------|--------------------|
 | `transducer`           | Conformer | LSTM               |
 | `transducer_stateless` | Conformer | Embedding + Conv1d |
-| `transducer_lstm `     | LSTM      | LSTM               |
+| `transducer_lstm `     | LSTM      | Embedding + Conv1d |
 
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index fee2a10ad..29d11c7a6 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -32,7 +32,7 @@ Usage:
         --exp-dir ./transducer_lstm/exp \
         --max-duration 100 \
         --decoding-method beam_search \
-        --beam-size 8
+        --beam-size 4
 """
 
 
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
+        default=4,
         help="Used only when --decoding-method is beam_search",
     )
@@ -122,6 +122,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="Maximum number of symbols per frame",
+    )
+
     return parser
 
 
@@ -132,8 +139,8 @@ def get_params() -> AttributeDict:
         "feature_dim": 80,
         "encoder_out_dim": 512,
         "subsampling_factor": 4,
-        "encoder_hidden_size": 2048,
-        "num_encoder_layers": 6,
+        "encoder_hidden_size": 1024,
+        "num_encoder_layers": 7,
         "proj_size": 512,
         "vgg_frontend": False,
         "env_info": get_env_info(),
@@ -237,7 +244,11 @@ def decode_one_batch(
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -381,6 +392,9 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
deleted file mode 100644
index dca084477..000000000
--- a/egs/librispeech/ASR/transducer_lstm/decoder.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Decoder(nn.Module):
-    """This class modifies the stateless decoder from the following paper:
-
-        RNN-transducer with stateless prediction network
-        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-
-    It removes the recurrent connection from the decoder, i.e., the
-    prediction network. Different from the above paper, it adds an extra
-    Conv1d right after the embedding layer.
-
-    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        context_size: int,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          context_size:
-            Number of previous words to use to predict the next word.
-            1 means bigram; 2 means trigram. n means (n+1)-gram.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.blank_id = blank_id
-
-        assert context_size >= 1, context_size
-        self.context_size = context_size
-        if context_size > 1:
-            self.conv = nn.Conv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
-                kernel_size=context_size,
-                padding=0,
-                groups=embedding_dim,
-                bias=False,
-            )
-
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U) with blank prepended.
-          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
-        Returns:
-          Return a tensor of shape (N, U, embedding_dim).
-        """
-        embeding_out = self.embedding(y)
-        if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
-            if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
-                )
-            else:
-                # During inference time, there is no need to do extra padding
-                # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
new file mode 120000
index 000000000..eada91097
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/decoder.py
@@ -0,0 +1 @@
+../transducer_stateless/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index 470c77b62..2f0f9a183 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -49,14 +49,14 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            one attributes `blank_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
-        assert isinstance(encoder, EncoderInterface)
+        assert isinstance(encoder, EncoderInterface), type(encoder)
         assert hasattr(decoder, "blank_id")
 
         self.encoder = encoder
@@ -100,7 +100,7 @@ class Transducer(nn.Module):
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
-        decoder_out, _ = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded)
 
         logits = self.joiner(encoder_out, decoder_out)
 
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index d931b2472..d63d7aebd 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -200,8 +200,8 @@ def get_params() -> AttributeDict:
         "feature_dim": 80,
         "encoder_out_dim": 512,
         "subsampling_factor": 4,
-        "encoder_hidden_size": 2048,
-        "num_encoder_layers": 6,
+        "encoder_hidden_size": 1024,
+        "num_encoder_layers": 7,
         "proj_size": 512,
         "vgg_frontend": False,
         # parameters for Noam