Use a similar number of parameters as the conformer encoder.

Fangjun Kuang 2021-12-28 11:09:43 +08:00
parent ec083e93d8
commit 3c89734b79
5 changed files with 28 additions and 111 deletions

README.md

@@ -13,7 +13,7 @@ The following table lists the differences among them.
 |------------------------|-----------|--------------------|
 | `transducer`           | Conformer | LSTM               |
 | `transducer_stateless` | Conformer | Embedding + Conv1d |
-| `transducer_lstm`      | LSTM      | LSTM               |
+| `transducer_lstm`      | LSTM      | Embedding + Conv1d |
 
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

transducer_lstm/decode.py

@@ -32,7 +32,7 @@ Usage:
         --exp-dir ./transducer_lstm/exp \
         --max-duration 100 \
         --decoding-method beam_search \
-        --beam-size 8
+        --beam-size 4
 """
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
+        default=4,
         help="Used only when --decoding-method is beam_search",
     )
@@ -122,6 +122,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="Maximum number of symbols per frame",
+    )
+
     return parser
@@ -132,8 +139,8 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
-            "encoder_hidden_size": 2048,
-            "num_encoder_layers": 6,
+            "encoder_hidden_size": 1024,
+            "num_encoder_layers": 7,
             "proj_size": 512,
             "vgg_frontend": False,
             "env_info": get_env_info(),
@@ -237,7 +244,11 @@ def decode_one_batch(
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
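
For orientation, here is a minimal sketch of what the new --max-sym-per-frame knob controls (my own illustration, not the repo's greedy_search): a transducer may emit any number of symbols without advancing the encoder frame, so greedy search caps the emissions per frame and then moves on.

    import torch

    def greedy_search_sketch(model, encoder_out, max_sym_per_frame=3):
        # encoder_out: (1, T, C); illustration only, details differ in icefall.
        blank_id = model.decoder.blank_id
        context_size = model.decoder.context_size
        hyp = [blank_id] * context_size  # decoder context, seeded with blanks
        t, sym_per_frame = 0, 0
        while t < encoder_out.size(1):
            decoder_in = torch.tensor([hyp[-context_size:]])
            decoder_out = model.decoder(decoder_in, need_pad=False)
            # joiner maps (N, T, C) x (N, U, C) -> (N, T, U, vocab_size)
            logits = model.joiner(encoder_out[:, t : t + 1], decoder_out)
            y = logits.argmax().item()
            if y == blank_id or sym_per_frame >= max_sym_per_frame:
                t += 1  # blank emitted or cap reached: advance to next frame
                sym_per_frame = 0
            else:
                hyp.append(y)  # non-blank: emit and stay on this frame
                sym_per_frame += 1
        return hyp[context_size:]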
@@ -381,6 +392,9 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")

transducer_lstm/decoder.py (deleted, replaced by the symlink below)

@@ -1,98 +0,0 @@
-# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Decoder(nn.Module):
-    """This class modifies the stateless decoder from the following paper:
-
-        RNN-transducer with stateless prediction network
-        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-
-    It removes the recurrent connection from the decoder, i.e., the
-    prediction network. Different from the above paper, it adds an extra
-    Conv1d right after the embedding layer.
-
-    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        context_size: int,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          context_size:
-            Number of previous words to use to predict the next word.
-            1 means bigram; 2 means trigram. n means (n+1)-gram.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.blank_id = blank_id
-        assert context_size >= 1, context_size
-        self.context_size = context_size
-        if context_size > 1:
-            self.conv = nn.Conv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
-                kernel_size=context_size,
-                padding=0,
-                groups=embedding_dim,
-                bias=False,
-            )
-
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U) with blank prepended.
-          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
-        Returns:
-          Return a tensor of shape (N, U, embedding_dim).
-        """
-        embedding_out = self.embedding(y)
-        if self.context_size > 1:
-            embedding_out = embedding_out.permute(0, 2, 1)
-            if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
-            else:
-                # During inference time, there is no need to do extra padding
-                # as we only need one output
-                assert embedding_out.size(-1) == self.context_size
-            embedding_out = self.conv(embedding_out)
-            embedding_out = embedding_out.permute(0, 2, 1)
-        return embedding_out
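
Since the file lives on via the symlink below, a short usage sketch may help (hypothetical sizes, not values from the recipe): training feeds the whole padded label sequence with need_pad=True, while search feeds exactly the last context_size labels with need_pad=False and gets one output frame back.

    import torch

    # Hypothetical sizes; the recipe takes its real values from get_params().
    decoder = Decoder(
        vocab_size=500, embedding_dim=512, blank_id=0, context_size=2
    )

    # Training: (N, U) labels with blank prepended -> (N, U, embedding_dim)
    y = torch.randint(low=1, high=500, size=(8, 20))
    assert decoder(y, need_pad=True).shape == (8, 20, 512)

    # Inference: exactly the last `context_size` labels, no extra padding
    context = torch.randint(low=1, high=500, size=(8, 2))
    assert decoder(context, need_pad=False).shape == (8, 1, 512)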

transducer_lstm/decoder.py (new symlink)

@@ -0,0 +1 @@
+../transducer_stateless/decoder.py
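
In other words, `transducer_lstm` now reuses the stateless decoder from `transducer_stateless`, which is exactly the switch the README change above records (decoder: LSTM -> Embedding + Conv1d).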

transducer_lstm/model.py

@@ -49,14 +49,14 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            one attributes `blank_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
-        assert isinstance(encoder, EncoderInterface)
+        assert isinstance(encoder, EncoderInterface), type(encoder)
         assert hasattr(decoder, "blank_id")
 
         self.encoder = encoder
@@ -100,7 +100,7 @@ class Transducer(nn.Module):
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
-        decoder_out, _ = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded)
 
         logits = self.joiner(encoder_out, decoder_out)
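
The decoder_out change follows from the decoder swap rather than being a style fix. A sketch of the difference (my illustration, not repo code): an LSTM prediction network returns (output, state), so the old tuple unpacking was correct there, whereas the stateless decoder returns a bare tensor, which unpacking would wrongly split along the batch dimension.

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=512, hidden_size=512, batch_first=True)
    x = torch.randn(8, 20, 512)
    out, (h, c) = lstm(x)  # LSTM returns (output, (h, c)): unpacking is fine

    # A bare tensor, as the stateless decoder returns, iterates along dim 0,
    # so `a, b = tensor` raises ValueError here (and would silently
    # mis-assign rows if the batch size happened to be 2).
    t = torch.randn(8, 20, 512)
    try:
        a, b = t
    except ValueError as e:
        print(e)  # too many values to unpack (expected 2)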

transducer_lstm/train.py

@@ -200,8 +200,8 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
-            "encoder_hidden_size": 2048,
-            "num_encoder_layers": 6,
+            "encoder_hidden_size": 1024,
+            "num_encoder_layers": 7,
             "proj_size": 512,
             "vgg_frontend": False,
             # parameters for Noam
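
As a rough check on the commit message (my own estimate: it assumes the encoder's core is a projected torch.nn.LSTM fed with proj_size-dimensional inputs, and it ignores the subsampling frontend), halving the hidden size while adding one layer cuts the RNN's parameter count from about 57M to about 33M:

    import torch.nn as nn

    def lstm_params(hidden_size: int, num_layers: int, proj_size: int = 512) -> int:
        # Parameter count of a projected LSTM whose input size equals proj_size.
        lstm = nn.LSTM(
            input_size=proj_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            proj_size=proj_size,
        )
        return sum(p.numel() for p in lstm.parameters())

    print(f"old: {lstm_params(2048, 6) / 1e6:.1f}M")  # ~56.7M
    print(f"new: {lstm_params(1024, 7) / 1e6:.1f}M")  # ~33.1M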