mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Use similar number of parameters as conformer encoder.
This commit is contained in:
parent
ec083e93d8
commit
3c89734b79
@ -13,7 +13,7 @@ The following table lists the differences among them.
|
||||
|------------------------|-----------|--------------------|
|
||||
| `transducer` | Conformer | LSTM |
|
||||
| `transducer_stateless` | Conformer | Embedding + Conv1d |
|
||||
| `transducer_lstm ` | LSTM | LSTM |
|
||||
| `transducer_lstm ` | LSTM | Embedding + Conv1d |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -32,7 +32,7 @@ Usage:
|
||||
--exp-dir ./transducer_lstm/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 8
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
@ -70,14 +70,14 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=77,
|
||||
default=29,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=55,
|
||||
default=13,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
@ -110,7 +110,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=5,
|
||||
default=4,
|
||||
help="Used only when --decoding-method is beam_search",
|
||||
)
|
||||
|
||||
@ -122,6 +122,13 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of symbols per frame",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -132,8 +139,8 @@ def get_params() -> AttributeDict:
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"encoder_hidden_size": 2048,
|
||||
"num_encoder_layers": 6,
|
||||
"encoder_hidden_size": 1024,
|
||||
"num_encoder_layers": 7,
|
||||
"proj_size": 512,
|
||||
"vgg_frontend": False,
|
||||
"env_info": get_env_info(),
|
||||
@ -237,7 +244,11 @@ def decode_one_batch(
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
@ -381,6 +392,9 @@ def main():
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.decoding_method == "beam_search":
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
@ -1,98 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
embedding_dim:
|
||||
Dimension of the input embedding.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=embedding_dim,
|
||||
out_channels=embedding_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=embedding_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U) with blank prepended.
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embeding_out = F.pad(
|
||||
embeding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embeding_out.size(-1) == self.context_size
|
||||
embeding_out = self.conv(embeding_out)
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
return embeding_out
|
1
egs/librispeech/ASR/transducer_lstm/decoder.py
Symbolic link
1
egs/librispeech/ASR/transducer_lstm/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/decoder.py
|
@ -49,14 +49,14 @@ class Transducer(nn.Module):
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, C). It should contain
|
||||
one attributes `blank_id`.
|
||||
one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface)
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder = encoder
|
||||
@ -100,7 +100,7 @@ class Transducer(nn.Module):
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
decoder_out, _ = self.decoder(sos_y_padded)
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
logits = self.joiner(encoder_out, decoder_out)
|
||||
|
||||
|
@ -200,8 +200,8 @@ def get_params() -> AttributeDict:
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"encoder_hidden_size": 2048,
|
||||
"num_encoder_layers": 6,
|
||||
"encoder_hidden_size": 1024,
|
||||
"num_encoder_layers": 7,
|
||||
"proj_size": 512,
|
||||
"vgg_frontend": False,
|
||||
# parameters for Noam
|
||||
|
Loading…
x
Reference in New Issue
Block a user