From 2cf1b56cb30657f26e972688ac413fb5810419ab Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Tue, 28 Dec 2021 10:38:22 +0800
Subject: [PATCH 1/3] Remove SOS from decoder.

---
 egs/librispeech/ASR/transducer_lstm/beam_search.py | 3 +--
 egs/librispeech/ASR/transducer_lstm/decode.py      | 4 +---
 egs/librispeech/ASR/transducer_lstm/decoder.py     | 4 ----
 egs/librispeech/ASR/transducer_lstm/model.py       | 6 ++----
 egs/librispeech/ASR/transducer_lstm/train.py       | 8 +-------
 5 files changed, 5 insertions(+), 20 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index dfc22fcf8..f45d06ce9 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -111,7 +111,6 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
     device = model.device
 
     sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(
 
             # Second, choose other labels
             for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
+                if i == blank_id:
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 18ae5234c..f22696613 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -155,7 +155,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -393,9 +392,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
index 2f6bf4c07..7b529ac19 100644
--- a/egs/librispeech/ASR/transducer_lstm/decoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/decoder.py
@@ -27,7 +27,6 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
         num_layers: int,
         hidden_dim: int,
         output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
           num_layers:
             Number of LSTM layers.
           hidden_dim:
@@ -71,7 +68,6 @@ class Decoder(nn.Module):
             dropout=rnn_dropout,
         )
         self.blank_id = blank_id
-        self.sos_id = sos_id
         self.output_linear = nn.Linear(hidden_dim, output_dim)
 
     def forward(
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index cb9afd8a2..470c77b62 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
       decoder:
         It is the prediction network in the paper. Its input shape
         is (N, U) and its output shape is (N, U, C). It should contain
-        two attributes: `blank_id` and `sos_id`.
+        one attributes `blank_id`.
      joiner:
        It has two inputs with shapes: (N, T, C) and (N, U, C). Its
        output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")
 
         self.encoder = encoder
         self.decoder = decoder
@@ -97,8 +96,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]
 
         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 62e9b5b12..f1d47b848 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -179,8 +179,6 @@ def get_params() -> AttributeDict:
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - weight_decay:  The weight_decay for the optimizer.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -206,7 +204,6 @@ def get_params() -> AttributeDict:
             "num_decoder_layers": 4,
             "decoder_hidden_dim": 512,
             # parameters for Noam
-            "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
@@ -232,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -568,9 +564,8 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -594,7 +589,6 @@ def run(rank, world_size, args):
         model_size=params.encoder_hidden_size,
         factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )
 
     if checkpoints and "optimizer" in checkpoints:

From ec083e93d8fd4018a253ffe5e1fb488dfd17afae Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Tue, 28 Dec 2021 10:49:50 +0800
Subject: [PATCH 2/3] Use a stateless decoder.

---
 egs/librispeech/ASR/transducer/decoder.py     |   2 +-
 .../ASR/transducer_lstm/beam_search.py        | 220 +-----------------
 egs/librispeech/ASR/transducer_lstm/decode.py |  22 +-
 .../ASR/transducer_lstm/decoder.py            |  95 ++++----
 egs/librispeech/ASR/transducer_lstm/train.py  |  30 +--
 5 files changed, 77 insertions(+), 292 deletions(-)
 mode change 100644 => 120000 egs/librispeech/ASR/transducer_lstm/beam_search.py

diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py
index 7b529ac19..4f4663aef 100644
--- a/egs/librispeech/ASR/transducer/decoder.py
+++ b/egs/librispeech/ASR/transducer/decoder.py
@@ -78,7 +78,7 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
+            A 2-D tensor of shape (N, U) with blank prepended.
           states:
             A tuple of two tensors containing the states information of
             LSTM layers in this decoder.
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
deleted file mode 100644
index f45d06ce9..000000000
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from model import Transducer
-
-
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
-    """
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    hyp = []
-
-    sym_per_frame = 0
-    sym_per_utt = 0
-
-    max_sym_per_utt = 1000
-    max_sym_per_frame = 3
-
-    while t < T and sym_per_utt < max_sym_per_utt:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
-        # logits is (1, 1, 1, vocab_size)
-
-        log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (1, 1, 1, vocab_size)
-        # TODO: Use logits.argmax()
-        y = log_prob.argmax()
-        if y != blank_id:
-            hyp.append(y.item())
-            y = y.reshape(1, 1)
-            decoder_out, (h, c) = model.decoder(y, (h, c))
-
-            sym_per_utt += 1
-            sym_per_frame += 1
-
-        if y == blank_id or sym_per_frame > max_sym_per_frame:
-            sym_per_frame = 0
-            t += 1
-
-    return hyp
-
-
-@dataclass
-class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
-
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
-
-
-def beam_search(
-    model: Transducer,
-    encoder_out: torch.Tensor,
-    beam: int = 5,
-) -> List[int]:
-    """
-    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
-
-    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
-
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-      beam:
-        Beam size.
-    Returns:
-      Return the decoded result.
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - device = model.device - - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)] - max_u = 20000 # terminate after this number of steps - u = 0 - - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} - - while t < T and u < max_u: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - A = B - B = [] - # for hyp in A: - # for h in A: - # if h.ys == hyp.ys[:-1]: - # # update the score of hyp - # decoder_input = torch.tensor( - # [h.ys[-1]], device=device - # ).reshape(1, 1) - # decoder_out, _ = model.decoder( - # decoder_input, h.decoder_state - # ) - # logits = model.joiner(current_encoder_out, decoder_out) - # log_prob = logits.log_softmax(dim=-1) - # log_prob = log_prob.squeeze() - # hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item() - - while u < max_u: - y_star = max(A, key=lambda hyp: hyp.log_prob) - A.remove(y_star) - - # Note: y_star.ys is unhashable, i.e., cannot be used - # as a key into a dict - cached_key = "_".join(map(str, y_star.ys)) - - if cached_key not in cache: - decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) - - decoder_out, decoder_state = model.decoder( - decoder_input, - y_star.decoder_state, - ) - cache[cached_key] = (decoder_out, decoder_state) - else: - decoder_out, decoder_state = cache[cached_key] - - logits = model.joiner(current_encoder_out, decoder_out) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - - # If we choose blank here, add the new hypothesis to B. 
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
-            skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
-
-            # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
-
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i == blank_id:
-                    continue
-                new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-            u += 1
-            # check whether B contains more than "beam" elements more probable
-            # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
-                break
-        t += 1
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
-    return ys
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
new file mode 120000
index 000000000..08cb32ef7
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -0,0 +1 @@
+../transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index f22696613..fee2a10ad 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -114,6 +114,14 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
 
 
@@ -124,14 +132,10 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
+            "encoder_hidden_size": 2048,
+            "num_encoder_layers": 6,
             "proj_size": 512,
             "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
     )
@@ -153,11 +157,9 @@ def get_encoder_model(params: AttributeDict):
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
index 7b529ac19..dca084477 100644
--- a/egs/librispeech/ASR/transducer_lstm/decoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/decoder.py
@@ -14,24 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
-# TODO(fangjun): Support switching between LSTM and GRU
 class Decoder(nn.Module):
+    """This class modifies the stateless decoder from the following paper:
+
+        RNN-transducer with stateless prediction network
+        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+
+    It removes the recurrent connection from the decoder, i.e., the prediction
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
+
+    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
+    """
+
     def __init__(
         self,
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        num_layers: int,
-        hidden_dim: int,
-        output_dim: int,
-        embedding_dropout: float = 0.0,
-        rnn_dropout: float = 0.0,
+        context_size: int,
     ):
         """
         Args:
@@ -42,16 +47,9 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          num_layers:
-            Number of LSTM layers.
-          hidden_dim:
-            Hidden dimension of LSTM layers.
-          output_dim:
-            Output dimension of the decoder.
-          embedding_dropout:
-            Dropout rate for the embedding layer.
-          rnn_dropout:
-            Dropout for LSTM layers.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -58,40 +57,42 @@ class Decoder(nn.Module):
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
-        self.embedding_dropout = nn.Dropout(embedding_dropout)
-        # TODO(fangjun): Use layer normalized LSTM
-        self.rnn = nn.LSTM(
-            input_size=embedding_dim,
-            hidden_size=hidden_dim,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=rnn_dropout,
-        )
         self.blank_id = blank_id
-        self.output_linear = nn.Linear(hidden_dim, output_dim)
 
-    def forward(
-        self,
-        y: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
-          states:
-            A tuple of two tensors containing the states information of
-            LSTM layers in this decoder.
+            A 2-D tensor of shape (N, U) with blank prepended.
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
         Returns:
-          Return a tuple containing:
-
-            - rnn_output, a tensor of shape (N, U, C)
-            - (h, c), containing the state information for LSTM layers.
-              Both are of shape (num_layers, N, C)
+          Return a tensor of shape (N, U, embedding_dim).
""" embeding_out = self.embedding(y) - embeding_out = self.embedding_dropout(embeding_out) - rnn_out, (h, c) = self.rnn(embeding_out, states) - out = self.output_linear(rnn_out) - - return out, (h, c) + if self.context_size > 1: + embeding_out = embeding_out.permute(0, 2, 1) + if need_pad is True: + embeding_out = F.pad( + embeding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embeding_out.size(-1) == self.context_size + embeding_out = self.conv(embeding_out) + embeding_out = embeding_out.permute(0, 2, 1) + return embeding_out diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index f1d47b848..d931b2472 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -131,6 +131,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -172,9 +180,6 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - attention_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. @@ -195,14 +200,10 @@ def get_params() -> AttributeDict: "feature_dim": 80, "encoder_out_dim": 512, "subsampling_factor": 4, - "encoder_hidden_size": 1024, - "num_encoder_layers": 4, + "encoder_hidden_size": 2048, + "num_encoder_layers": 6, "proj_size": 512, "vgg_frontend": False, - # decoder params - "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, - "decoder_hidden_dim": 512, # parameters for Noam "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(), @@ -227,12 +228,11 @@ def get_encoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict): decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.decoder_embedding_dim, + embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.decoder_hidden_dim, - output_dim=params.encoder_out_dim, + context_size=params.context_size, ) + return decoder @@ -573,11 +573,11 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) - checkpoints = load_checkpoint_if_available(params=params, model=model) - num_param = sum([p.numel() for p in model.parameters() if p.requires_grad]) logging.info(f"Number of model parameters: {num_param}") + checkpoints = load_checkpoint_if_available(params=params, model=model) + model.to(device) if world_size > 1: logging.info("Using DDP") From 3c89734b79357508c0274b04895f64cc7d65a446 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 Dec 2021 11:09:43 +0800 Subject: [PATCH 3/3] Use similar number of parameters as conformer encoder. 
---
 egs/librispeech/ASR/README.md                 |  2 +-
 egs/librispeech/ASR/transducer_lstm/decode.py | 28 ++++--
 .../ASR/transducer_lstm/decoder.py            | 99 +------------------
 egs/librispeech/ASR/transducer_lstm/model.py  |  6 +-
 egs/librispeech/ASR/transducer_lstm/train.py  |  4 +-
 5 files changed, 28 insertions(+), 111 deletions(-)
 mode change 100644 => 120000 egs/librispeech/ASR/transducer_lstm/decoder.py

diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index c8ee98d7d..9fd657eec 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -13,7 +13,7 @@ The following table lists the differences among them.
 |------------------------|-----------|--------------------|
 | `transducer`           | Conformer | LSTM               |
 | `transducer_stateless` | Conformer | Embedding + Conv1d |
-| `transducer_lstm     ` | LSTM      | LSTM               |
+| `transducer_lstm     ` | LSTM      | Embedding + Conv1d |
 
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index fee2a10ad..29d11c7a6 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -32,7 +32,7 @@ Usage:
         --exp-dir ./transducer_lstm/exp \
         --max-duration 100 \
         --decoding-method beam_search \
-        --beam-size 8
+        --beam-size 4
 """
 
 
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
", @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--beam-size", type=int, - default=5, + default=4, help="Used only when --decoding-method is beam_search", ) @@ -122,6 +122,13 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="Maximum number of symbols per frame", + ) + return parser @@ -132,8 +139,8 @@ def get_params() -> AttributeDict: "feature_dim": 80, "encoder_out_dim": 512, "subsampling_factor": 4, - "encoder_hidden_size": 2048, - "num_encoder_layers": 6, + "encoder_hidden_size": 1024, + "num_encoder_layers": 7, "proj_size": 512, "vgg_frontend": False, "env_info": get_env_info(), @@ -237,7 +244,11 @@ def decode_one_batch( encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size @@ -381,6 +392,9 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.decoding_method == "beam_search": params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py deleted file mode 100644 index dca084477..000000000 --- a/egs/librispeech/ASR/transducer_lstm/decoder.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - embedding_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. 
- """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=embedding_dim, - out_channels=embedding_dim, - kernel_size=context_size, - padding=0, - groups=embedding_dim, - bias=False, - ) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U) with blank prepended. - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, embedding_dim). - """ - embeding_out = self.embedding(y) - if self.context_size > 1: - embeding_out = embeding_out.permute(0, 2, 1) - if need_pad is True: - embeding_out = F.pad( - embeding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embeding_out.size(-1) == self.context_size - embeding_out = self.conv(embeding_out) - embeding_out = embeding_out.permute(0, 2, 1) - return embeding_out diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py new file mode 120000 index 000000000..eada91097 --- /dev/null +++ b/egs/librispeech/ASR/transducer_lstm/decoder.py @@ -0,0 +1 @@ +../transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index 470c77b62..2f0f9a183 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -49,14 +49,14 @@ class Transducer(nn.Module): decoder: It is the prediction network in the paper. Its input shape is (N, U) and its output shape is (N, U, C). It should contain - one attributes `blank_id`. + one attribute: `blank_id`. joiner: It has two inputs with shapes: (N, T, C) and (N, U, C). Its output shape is (N, T, U, C). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. """ super().__init__() - assert isinstance(encoder, EncoderInterface) + assert isinstance(encoder, EncoderInterface), type(encoder) assert hasattr(decoder, "blank_id") self.encoder = encoder @@ -100,7 +100,7 @@ class Transducer(nn.Module): sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - decoder_out, _ = self.decoder(sos_y_padded) + decoder_out = self.decoder(sos_y_padded) logits = self.joiner(encoder_out, decoder_out) diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index d931b2472..d63d7aebd 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -200,8 +200,8 @@ def get_params() -> AttributeDict: "feature_dim": 80, "encoder_out_dim": 512, "subsampling_factor": 4, - "encoder_hidden_size": 2048, - "num_encoder_layers": 6, + "encoder_hidden_size": 1024, + "num_encoder_layers": 7, "proj_size": 512, "vgg_frontend": False, # parameters for Noam