diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py
index 7b529ac19..4f4663aef 100644
--- a/egs/librispeech/ASR/transducer/decoder.py
+++ b/egs/librispeech/ASR/transducer/decoder.py
@@ -78,7 +78,7 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
+            A 2-D tensor of shape (N, U) with blank prepended.
           states:
             A tuple of two tensors containing the states information of
             LSTM layers in this decoder.
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
deleted file mode 100644
index f45d06ce9..000000000
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from model import Transducer
-
-
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
-    """
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    hyp = []
-
-    sym_per_frame = 0
-    sym_per_utt = 0
-
-    max_sym_per_utt = 1000
-    max_sym_per_frame = 3
-
-    while t < T and sym_per_utt < max_sym_per_utt:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
-        # logits is (1, 1, 1, vocab_size)
-
-        log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (1, 1, 1, vocab_size)
-        # TODO: Use logits.argmax()
-        y = log_prob.argmax()
-        if y != blank_id:
-            hyp.append(y.item())
-            y = y.reshape(1, 1)
-            decoder_out, (h, c) = model.decoder(y, (h, c))
-
-            sym_per_utt += 1
-            sym_per_frame += 1
-
-        if y == blank_id or sym_per_frame > max_sym_per_frame:
-            sym_per_frame = 0
-            t += 1
-
-    return hyp
-
-
-@dataclass
-class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
-
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
-
-
-def beam_search(
-    model: Transducer,
-    encoder_out: torch.Tensor,
-    beam: int = 5,
-) -> List[int]:
-    """
-    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
-
-    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
-
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-      beam:
-        Beam size.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
-    max_u = 20000  # terminate after this number of steps
-    u = 0
-
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
-
-    while t < T and u < max_u:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        A = B
-        B = []
-        # for hyp in A:
-        #     for h in A:
-        #         if h.ys == hyp.ys[:-1]:
-        #             # update the score of hyp
-        #             decoder_input = torch.tensor(
-        #                 [h.ys[-1]], device=device
-        #             ).reshape(1, 1)
-        #             decoder_out, _ = model.decoder(
-        #                 decoder_input, h.decoder_state
-        #             )
-        #             logits = model.joiner(current_encoder_out, decoder_out)
-        #             log_prob = logits.log_softmax(dim=-1)
-        #             log_prob = log_prob.squeeze()
-        #             hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
-
-        while u < max_u:
-            y_star = max(A, key=lambda hyp: hyp.log_prob)
-            A.remove(y_star)
-
-            # Note: y_star.ys is unhashable, i.e., cannot be used
-            # as a key into a dict
-            cached_key = "_".join(map(str, y_star.ys))
-
-            if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
-
-                decoder_out, decoder_state = model.decoder(
-                    decoder_input,
-                    y_star.decoder_state,
-                )
-                cache[cached_key] = (decoder_out, decoder_state)
-            else:
-                decoder_out, decoder_state = cache[cached_key]
-
-            logits = model.joiner(current_encoder_out, decoder_out)
-            log_prob = logits.log_softmax(dim=-1)
-            # log_prob is (1, 1, 1, vocab_size)
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
-
-            # If we choose blank here, add the new hypothesis to B.
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
-            skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
-
-            # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
-
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i == blank_id:
-                    continue
-                new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-            u += 1
-            # check whether B contains more than "beam" elements more probable
-            # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
-                break
-        t += 1
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
-    return ys
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
new file mode 120000
index 000000000..08cb32ef7
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -0,0 +1 @@
+../transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index f22696613..fee2a10ad 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -114,6 +114,14 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
 
 
@@ -124,14 +132,10 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
+            "encoder_hidden_size": 2048,
+            "num_encoder_layers": 6,
             "proj_size": 512,
             "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
     )
@@ -153,11 +157,9 @@ def get_encoder_model(params: AttributeDict):
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder
 
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
index 7b529ac19..dca084477 100644
--- a/egs/librispeech/ASR/transducer_lstm/decoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/decoder.py
@@ -14,24 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
-# TODO(fangjun): Support switching between LSTM and GRU
 class Decoder(nn.Module):
+    """This class modifies the stateless decoder from the following paper:
+
+        RNN-transducer with stateless prediction network
+        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+
+    It removes the recurrent connection from the decoder, i.e., the prediction
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
+
+    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
+    """
+
     def __init__(
         self,
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        num_layers: int,
-        hidden_dim: int,
-        output_dim: int,
-        embedding_dropout: float = 0.0,
-        rnn_dropout: float = 0.0,
+        context_size: int,
     ):
         """
         Args:
@@ -41,16 +47,9 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          num_layers:
-            Number of LSTM layers.
-          hidden_dim:
-            Hidden dimension of LSTM layers.
-          output_dim:
-            Output dimension of the decoder.
-          embedding_dropout:
-            Dropout rate for the embedding layer.
-          rnn_dropout:
-            Dropout for LSTM layers.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -58,40 +57,42 @@ class Decoder(nn.Module):
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
-        self.embedding_dropout = nn.Dropout(embedding_dropout)
-        # TODO(fangjun): Use layer normalized LSTM
-        self.rnn = nn.LSTM(
-            input_size=embedding_dim,
-            hidden_size=hidden_dim,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=rnn_dropout,
-        )
         self.blank_id = blank_id
-        self.output_linear = nn.Linear(hidden_dim, output_dim)
 
-    def forward(
-        self,
-        y: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
-          states:
-            A tuple of two tensors containing the states information of
-            LSTM layers in this decoder.
+            A 2-D tensor of shape (N, U) with blank prepended.
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
         Returns:
-          Return a tuple containing:
-
-            - rnn_output, a tensor of shape (N, U, C)
-            - (h, c), containing the state information for LSTM layers.
-              Both are of shape (num_layers, N, C)
+          Return a tensor of shape (N, U, embedding_dim).
""" embeding_out = self.embedding(y) - embeding_out = self.embedding_dropout(embeding_out) - rnn_out, (h, c) = self.rnn(embeding_out, states) - out = self.output_linear(rnn_out) - - return out, (h, c) + if self.context_size > 1: + embeding_out = embeding_out.permute(0, 2, 1) + if need_pad is True: + embeding_out = F.pad( + embeding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embeding_out.size(-1) == self.context_size + embeding_out = self.conv(embeding_out) + embeding_out = embeding_out.permute(0, 2, 1) + return embeding_out diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index f1d47b848..d931b2472 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -131,6 +131,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -172,9 +180,6 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - attention_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. @@ -195,14 +200,10 @@ def get_params() -> AttributeDict: "feature_dim": 80, "encoder_out_dim": 512, "subsampling_factor": 4, - "encoder_hidden_size": 1024, - "num_encoder_layers": 4, + "encoder_hidden_size": 2048, + "num_encoder_layers": 6, "proj_size": 512, "vgg_frontend": False, - # decoder params - "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, - "decoder_hidden_dim": 512, # parameters for Noam "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(), @@ -227,12 +228,11 @@ def get_encoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict): decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.decoder_embedding_dim, + embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.decoder_hidden_dim, - output_dim=params.encoder_out_dim, + context_size=params.context_size, ) + return decoder @@ -573,11 +573,11 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) - checkpoints = load_checkpoint_if_available(params=params, model=model) - num_param = sum([p.numel() for p in model.parameters() if p.requires_grad]) logging.info(f"Number of model parameters: {num_param}") + checkpoints = load_checkpoint_if_available(params=params, model=model) + model.to(device) if world_size > 1: logging.info("Using DDP")