diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index de9d6d50a..4f28dba9b 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -14,7 +14,7 @@ The following table lists the differences among them.
 | `transducer` | Conformer | LSTM | |
 | `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer for computing RNN-T loss |
 | `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
-| `transducer_lstm` | LSTM | LSTM | |
+| `transducer_lstm` | LSTM | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
 | `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
 | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
deleted file mode 100644
index 3531a9633..000000000
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from model import Transducer
-
-
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
-    """
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    hyp = []
-
-    sym_per_frame = 0
-    sym_per_utt = 0
-
-    max_sym_per_utt = 1000
-    max_sym_per_frame = 3
-
-    while t < T and sym_per_utt < max_sym_per_utt:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
-        # logits is (1, 1, 1, vocab_size)
-
-        log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (1, 1, 1, vocab_size)
-        # TODO: Use logits.argmax()
-        y = log_prob.argmax()
-        if y != blank_id:
-            hyp.append(y.item())
-            y = y.reshape(1, 1)
-            decoder_out, (h, c) = model.decoder(y, (h, c))
-
-            sym_per_utt += 1
-            sym_per_frame += 1
-
-        if y == blank_id or sym_per_frame > max_sym_per_frame:
-            sym_per_frame = 0
-            t += 1
-
-    return hyp
-
-
-@dataclass
-class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
-
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
-
-
-def beam_search(
-    model: Transducer,
-    encoder_out: torch.Tensor,
-    beam: int = 5,
-) -> List[int]:
-    """
-    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
-
-    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
-
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-      beam:
-        Beam size.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
-    max_u = 20000  # terminate after this number of steps
-    u = 0
-
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
-
-    while t < T and u < max_u:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        A = B
-        B = []
-        # for hyp in A:
-        #     for h in A:
-        #         if h.ys == hyp.ys[:-1]:
-        #             # update the score of hyp
-        #             decoder_input = torch.tensor(
-        #                 [h.ys[-1]], device=device
-        #             ).reshape(1, 1)
-        #             decoder_out, _ = model.decoder(
-        #                 decoder_input, h.decoder_state
-        #             )
-        #             logits = model.joiner(current_encoder_out, decoder_out)
-        #             log_prob = logits.log_softmax(dim=-1)
-        #             log_prob = log_prob.squeeze()
-        #             hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
-
-        while u < max_u:
-            y_star = max(A, key=lambda hyp: hyp.log_prob)
-            A.remove(y_star)
-
-            # Note: y_star.ys is unhashable, i.e., cannot be used
-            # as a key into a dict
-            cached_key = "_".join(map(str, y_star.ys))
-
-            if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
-
-                decoder_out, decoder_state = model.decoder(
-                    decoder_input,
-                    y_star.decoder_state,
-                )
-                cache[cached_key] = (decoder_out, decoder_state)
-            else:
-                decoder_out, decoder_state = cache[cached_key]
-
-            logits = model.joiner(current_encoder_out, decoder_out)
-            log_prob = logits.log_softmax(dim=-1)
-            # log_prob is (1, 1, 1, vocab_size)
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
-
-            # If we choose blank here, add the new hypothesis to B.
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
-            skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
-
-            # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
-
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
-                    continue
-                new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-            u += 1
-            # check whether B contains more than "beam" elements more probable
-            # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
-                break
-        t += 1
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
-    return ys
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
new file mode 120000
index 000000000..08cb32ef7
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -0,0 +1 @@
+../transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 18ae5234c..03d1b840b 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -46,14 +46,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search
-from decoder import Decoder
-from encoder import LstmEncoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -104,6 +105,7 @@ def get_parser():
         help="""Possible values are:
          - greedy_search
          - beam_search
+         - modified_beam_search
         """,
     )

@@ -114,76 +116,25 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
-            "proj_size": 512,
-            "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    encoder = LstmEncoder(
-        num_features=params.feature_dim,
-        hidden_size=params.encoder_hidden_size,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
-        blank_id=params.blank_id,
-        sos_id=params.sos_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -228,24 +179,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
+    hyp_list: List[List[int]] = []

-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -393,9 +367,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
diff --git a/egs/librispeech/ASR/transducer_lstm/decoder.py b/egs/librispeech/ASR/transducer_lstm/decoder.py
index 2f6bf4c07..b82fed37b 100644
--- a/egs/librispeech/ASR/transducer_lstm/decoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/decoder.py
@@ -14,25 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F


-# TODO(fangjun): Support switching between LSTM and GRU
 class Decoder(nn.Module):
+    """This class modifies the stateless decoder from the following paper:
+
+        RNN-transducer with stateless prediction network
+        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+
+    It removes the recurrent connection from the decoder, i.e., the prediction
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
+
+    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
+    """
+
     def __init__(
         self,
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
-        num_layers: int,
-        hidden_dim: int,
-        output_dim: int,
-        embedding_dropout: float = 0.0,
-        rnn_dropout: float = 0.0,
+        context_size: int,
     ):
         """
         Args:
@@ -42,18 +47,9 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
-          num_layers:
-            Number of LSTM layers.
-          hidden_dim:
-            Hidden dimension of LSTM layers.
-          output_dim:
-            Output dimension of the decoder.
-          embedding_dropout:
-            Dropout rate for the embedding layer.
-          rnn_dropout:
-            Dropout for LSTM layers.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -61,41 +57,42 @@ class Decoder(nn.Module):
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
-        self.embedding_dropout = nn.Dropout(embedding_dropout)
-        # TODO(fangjun): Use layer normalized LSTM
-        self.rnn = nn.LSTM(
-            input_size=embedding_dim,
-            hidden_size=hidden_dim,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=rnn_dropout,
-        )
         self.blank_id = blank_id
-        self.sos_id = sos_id
-        self.output_linear = nn.Linear(hidden_dim, output_dim)

-    def forward(
-        self,
-        y: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
-          states:
-            A tuple of two tensors containing the states information of
-            LSTM layers in this decoder.
+            A 2-D tensor of shape (N, U).
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
         Returns:
-          Return a tuple containing:
-
-            - rnn_output, a tensor of shape (N, U, C)
-            - (h, c), containing the state information for LSTM layers.
-              Both are of shape (num_layers, N, C)
+          Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
-        out = self.output_linear(rnn_out)
-
-        return out, (h, c)
+        embedding_out = self.embedding(y)
+        if self.context_size > 1:
+            embedding_out = embedding_out.permute(0, 2, 1)
+            if need_pad is True:
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
+            else:
+                # During inference time, there is no need to do extra padding
+                # as we only need one output
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 860a84bb1..50c31275c 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Tuple

 import torch
@@ -87,7 +88,9 @@ class LstmEncoder(EncoderInterface):
         x = self.encoder_embed(x)

         # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            lengths = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(1) == lengths.max().item(), (
             x.size(1),
             lengths.max(),
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index 31843b60e..02c5eabb0 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")

         self.encoder = encoder
         self.decoder = decoder
@@ -97,13 +96,12 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]

         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id

-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)

         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
         sos_y_padded = sos_y_padded.to(torch.int64)

-        decoder_out, _ = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded)

         logits = self.joiner(encoder_out, decoder_out)
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index eef4d3430..7f4dc32cf 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -139,6 +139,14 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser


@@ -235,15 +243,12 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder

@@ -400,9 +405,11 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            info["frames"] = (
+                (feature_lens // params.subsampling_factor).sum().item()
+            )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -580,9 +587,8 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
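Editor's note (not part of the patch above): the core of this change is replacing the LSTM prediction network with the stateless Embedding + Conv1d decoder shown in `decoder.py`. The sketch below is a minimal, self-contained illustration of that decoder's `context_size`/`need_pad` behaviour; the class name, the toy dimensions, and the `__main__` demo are invented for illustration and are not part of the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    """Minimal Embedding + Conv1d prediction network (same idea as decoder.py)."""

    def __init__(
        self, vocab_size: int, embedding_dim: int, blank_id: int, context_size: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
        self.blank_id = blank_id
        self.context_size = context_size
        if context_size > 1:
            # Depthwise Conv1d mixing the embeddings of the last `context_size` labels.
            self.conv = nn.Conv1d(
                embedding_dim,
                embedding_dim,
                kernel_size=context_size,
                groups=embedding_dim,
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        # y: (N, U) label indices -> output: (N, U', embedding_dim)
        out = self.embedding(y)
        if self.context_size > 1:
            out = out.permute(0, 2, 1)  # (N, embedding_dim, U) for Conv1d
            if need_pad:
                # Training: left-pad so each position only sees previous labels.
                out = F.pad(out, pad=(self.context_size - 1, 0))
            else:
                # Inference: the caller feeds exactly `context_size` labels per step.
                assert out.size(-1) == self.context_size
            out = self.conv(out)
            out = out.permute(0, 2, 1)
        return out


if __name__ == "__main__":
    dec = StatelessDecoder(vocab_size=500, embedding_dim=512, blank_id=0, context_size=2)
    # Training mode: full label sequence, one decoder output per label position.
    print(dec(torch.randint(1, 500, (4, 13)), need_pad=True).shape)   # torch.Size([4, 13, 512])
    # Decoding mode: only the last `context_size` labels, a single output per step.
    print(dec(torch.randint(1, 500, (4, 2)), need_pad=False).shape)   # torch.Size([4, 1, 512])
```

Because the output depends only on the last `context_size` labels, the search code no longer has to carry LSTM `(h, c)` states between steps, which is what allows `beam_search.py` to become a symlink to the `transducer_stateless` implementation.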