diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index c8ee98d7d..9fd657eec 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -13,7 +13,7 @@ The following table lists the differences among them.
 |------------------------|-----------|--------------------|
 | `transducer`           | Conformer | LSTM               |
 | `transducer_stateless` | Conformer | Embedding + Conv1d |
-| `transducer_lstm `     | LSTM      | LSTM               |
+| `transducer_lstm `     | LSTM      | Embedding + Conv1d |
 
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py
index 7b529ac19..4f4663aef 100644
--- a/egs/librispeech/ASR/transducer/decoder.py
+++ b/egs/librispeech/ASR/transducer/decoder.py
@@ -78,7 +78,7 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
+            A 2-D tensor of shape (N, U) with blank prepended.
           states:
             A tuple of two tensors containing the states information of
             LSTM layers in this decoder.
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
deleted file mode 100644
index dfc22fcf8..000000000
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ /dev/null
@@ -1,220 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from model import Transducer
-
-
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
-    """
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-    Returns:
-      Return the decoded result.
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - device = model.device - - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - hyp = [] - - sym_per_frame = 0 - sym_per_utt = 0 - - max_sym_per_utt = 1000 - max_sym_per_frame = 3 - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) - # logits is (1, 1, 1, vocab_size) - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - # TODO: Use logits.argmax() - y = log_prob.argmax() - if y != blank_id: - hyp.append(y.item()) - y = y.reshape(1, 1) - decoder_out, (h, c) = model.decoder(y, (h, c)) - - sym_per_utt += 1 - sym_per_frame += 1 - - if y == blank_id or sym_per_frame > max_sym_per_frame: - sym_per_frame = 0 - t += 1 - - return hyp - - -@dataclass -class Hypothesis: - ys: List[int] # the predicted sequences so far - log_prob: float # The log prob of ys - - # Optional decoder state. We assume it is LSTM for now, - # so the state is a tuple (h, c) - decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 5, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. 
- """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - sos_id = model.decoder.sos_id - device = model.device - - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) - T = encoder_out.size(1) - t = 0 - B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)] - max_u = 20000 # terminate after this number of steps - u = 0 - - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} - - while t < T and u < max_u: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :] - # fmt: on - A = B - B = [] - # for hyp in A: - # for h in A: - # if h.ys == hyp.ys[:-1]: - # # update the score of hyp - # decoder_input = torch.tensor( - # [h.ys[-1]], device=device - # ).reshape(1, 1) - # decoder_out, _ = model.decoder( - # decoder_input, h.decoder_state - # ) - # logits = model.joiner(current_encoder_out, decoder_out) - # log_prob = logits.log_softmax(dim=-1) - # log_prob = log_prob.squeeze() - # hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item() - - while u < max_u: - y_star = max(A, key=lambda hyp: hyp.log_prob) - A.remove(y_star) - - # Note: y_star.ys is unhashable, i.e., cannot be used - # as a key into a dict - cached_key = "_".join(map(str, y_star.ys)) - - if cached_key not in cache: - decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) - - decoder_out, decoder_state = model.decoder( - decoder_input, - y_star.decoder_state, - ) - cache[cached_key] = (decoder_out, decoder_state) - else: - decoder_out, decoder_state = cache[cached_key] - - logits = model.joiner(current_encoder_out, decoder_out) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - - # If we choose blank here, add the new hypothesis to B. 
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
-            skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
-
-            # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
-
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
-                    continue
-                new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-            u += 1
-            # check whether B contains more than "beam" elements more probable
-            # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
-                break
-        t += 1
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
-    return ys
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
new file mode 120000
index 000000000..08cb32ef7
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -0,0 +1 @@
+../transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 18ae5234c..29d11c7a6 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -32,7 +32,7 @@ Usage:
         --exp-dir ./transducer_lstm/exp \
         --max-duration 100 \
         --decoding-method beam_search \
-        --beam-size 8
+        --beam-size 4
 
 """
 
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -110,10 +110,25 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
+        default=4,
         help="Used only when --decoding-method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="Maximum number of symbols per frame",
+    )
+
     return parser
 
@@ -125,13 +140,9 @@
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
             "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
+            "num_encoder_layers": 7,
             "proj_size": 512,
             "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
     )
@@ -153,12 +164,9 @@ def get_encoder_model(params: AttributeDict):
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder
 
@@ -236,7 +244,11 @@ def decode_one_batch(
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -380,6 +392,9 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -393,9 +408,8 @@
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
deleted file mode 100644
index 2f6bf4c07..000000000
--- a/egs/librispeech/ASR/transducer_lstm/decoder.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-
-
-# TODO(fangjun): Support switching between LSTM and GRU
-class Decoder(nn.Module):
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        sos_id: int,
-        num_layers: int,
-        hidden_dim: int,
-        output_dim: int,
-        embedding_dropout: float = 0.0,
-        rnn_dropout: float = 0.0,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
-          num_layers:
-            Number of LSTM layers.
-          hidden_dim:
-            Hidden dimension of LSTM layers.
-          output_dim:
-            Output dimension of the decoder.
-          embedding_dropout:
-            Dropout rate for the embedding layer.
-          rnn_dropout:
-            Dropout for LSTM layers.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.embedding_dropout = nn.Dropout(embedding_dropout)
-        # TODO(fangjun): Use layer normalized LSTM
-        self.rnn = nn.LSTM(
-            input_size=embedding_dim,
-            hidden_size=hidden_dim,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=rnn_dropout,
-        )
-        self.blank_id = blank_id
-        self.sos_id = sos_id
-        self.output_linear = nn.Linear(hidden_dim, output_dim)
-
-    def forward(
-        self,
-        y: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
-          states:
-            A tuple of two tensors containing the states information of
-            LSTM layers in this decoder.
-        Returns:
-          Return a tuple containing:
-
-            - rnn_output, a tensor of shape (N, U, C)
-            - (h, c), containing the state information for LSTM layers.
-              Both are of shape (num_layers, N, C)
-        """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
-        out = self.output_linear(rnn_out)
-
-        return out, (h, c)
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
new file mode 120000
index 000000000..eada91097
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/decoder.py
@@ -0,0 +1 @@
+../transducer_stateless/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index cb9afd8a2..2f0f9a183 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -49,16 +49,15 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
""" super().__init__() - assert isinstance(encoder, EncoderInterface) + assert isinstance(encoder, EncoderInterface), type(encoder) assert hasattr(decoder, "blank_id") - assert hasattr(decoder, "sos_id") self.encoder = encoder self.decoder = decoder @@ -97,12 +96,11 @@ class Transducer(nn.Module): y_lens = row_splits[1:] - row_splits[:-1] blank_id = self.decoder.blank_id - sos_id = self.decoder.sos_id - sos_y = add_sos(y, sos_id=sos_id) + sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - decoder_out, _ = self.decoder(sos_y_padded) + decoder_out = self.decoder(sos_y_padded) logits = self.joiner(encoder_out, decoder_out) diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 62e9b5b12..d63d7aebd 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -131,6 +131,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -172,15 +180,10 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - attention_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. - - weight_decay: The weight_decay for the optimizer. - - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( @@ -198,15 +201,10 @@ def get_params() -> AttributeDict: "encoder_out_dim": 512, "subsampling_factor": 4, "encoder_hidden_size": 1024, - "num_encoder_layers": 4, + "num_encoder_layers": 7, "proj_size": 512, "vgg_frontend": False, - # decoder params - "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, - "decoder_hidden_dim": 512, # parameters for Noam - "weight_decay": 1e-6, "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(), } @@ -230,13 +228,11 @@ def get_encoder_model(params: AttributeDict): def get_decoder_model(params: AttributeDict): decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.decoder_embedding_dim, + embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, - sos_id=params.sos_id, - num_layers=params.num_decoder_layers, - hidden_dim=params.decoder_hidden_dim, - output_dim=params.encoder_out_dim, + context_size=params.context_size, ) + return decoder @@ -568,9 +564,8 @@ def run(rank, world_size, args): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -578,11 +573,11 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) - checkpoints = load_checkpoint_if_available(params=params, model=model) - num_param = sum([p.numel() for p in model.parameters() if p.requires_grad]) logging.info(f"Number of model parameters: {num_param}") + checkpoints = load_checkpoint_if_available(params=params, model=model) + model.to(device) if world_size > 1: logging.info("Using DDP") @@ -594,7 +589,6 @@ def run(rank, world_size, args): model_size=params.encoder_hidden_size, factor=params.lr_factor, warm_step=params.warm_step, - 
weight_decay=params.weight_decay, ) if checkpoints and "optimizer" in checkpoints:
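
The symlinked decoder replaces the LSTM prediction network with the stateless one from `transducer_stateless`: an embedding followed by a 1-D convolution over the last `context_size` labels. There is no `(h, c)` state to carry around, and no separate `sos_id` is needed because blank doubles as the start-of-sequence symbol (hence `add_sos(y, sos_id=blank_id)` in model.py above). Below is a minimal sketch of such a decoder, inferred only from the interface this diff relies on (constructor arguments `vocab_size`, `embedding_dim`, `blank_id`, `context_size`; input `(N, U)`; output `(N, U, C)`); the authoritative implementation lives in `../transducer_stateless/decoder.py` and may differ in detail. The class name and the `need_pad` flag are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        context_size: int,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id  # the only attribute model.py now requires
        self.context_size = context_size
        if context_size > 1:
            # A depthwise 1-D convolution mixes the embeddings of the last
            # `context_size` labels, so the decoder "state" is just the
            # previous context_size - 1 labels instead of an LSTM (h, c).
            self.conv = nn.Conv1d(
                in_channels=embedding_dim,
                out_channels=embedding_dim,
                kernel_size=context_size,
                groups=embedding_dim,
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """y is (N, U) with blank prepended; returns (N, U, C)."""
        embedding_out = self.embedding(y)
        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)  # (N, C, U)
            if need_pad:
                # Left-pad so position u only sees labels
                # u - context_size + 1 .. u (causal context).
                embedding_out = F.pad(
                    embedding_out, pad=(self.context_size - 1, 0)
                )
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)  # (N, U, C)
        return F.relu(embedding_out)

As a shape check, `StatelessDecoder(500, 512, 0, 2)(torch.zeros(4, 10, dtype=torch.long))` returns a `(4, 10, 512)` tensor, matching the `(N, U, C)` contract documented in model.py. With the default `--context-size 2`, the prediction network conditions on two previous labels, i.e. it behaves like a tri-gram model over BPE tokens, as the new command-line help describes.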