diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 49b1308b0..0e3b0f197 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -483,8 +483,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
index 8c728fdc5..f4355e8a0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
@@ -37,6 +37,7 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
+        unk_id: int,
         context_size: int,
     ):
         """
@@ -47,6 +48,8 @@
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol.
+          unk_id:
+            The ID of the unk symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
@@ -58,6 +61,7 @@
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
+        self.unk_id = unk_id
         assert context_size >= 1, context_size
         self.context_size = context_size
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index e743106ec..f0ea12d62 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -319,6 +319,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
+        unk_id=params.unk_id,
         context_size=params.context_size,
     )
     return decoder
@@ -756,8 +757,9 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index 79fb320f9..beeeb047b 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -12,7 +12,7 @@ The WERs are
 |------------------------------------|------------|------------|------------------------------------------|
 | greedy search                      | 7.27       | 6.69       | --epoch 29, --avg 13, --max-duration 100 |
 | beam search (beam size 4)          | 6.70       | 6.04       | --epoch 29, --avg 13, --max-duration 100 |
-| modified beam search (beam size 4) | 6.72       | 6.12       | --epoch 29, --avg 13, --max-duration 100 |
+| modified beam search (beam size 4) | 6.77       | 6.12       | --epoch 29, --avg 13, --max-duration 100 |
 | fast beam search (set as default)  | 7.14       | 6.50       | --epoch 29, --avg 13, --max-duration 1500|
 
 The training command for reproducing is given below:
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py
index 061d09e2f..0ae001d3f 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py
@@ -1,5 +1,5 @@
 # Copyright    2020  Xiaomi Corp.        (authors: Fangjun Kuang
-#                                                  Mingshuang Luo)
+#                                                   Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -369,13 +369,158 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
+def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      beam:
+        Number of active paths during the beam search.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    unk_id = model.decoder.unk_id
+    context_size = model.decoder.context_size
+    device = model.device
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            )
+        )
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            topk_hyp_indexes = torch.div(
+                topk_indexes, vocab_size, rounding_mode="trunc"
+            )
+            topk_hyp_indexes = topk_hyp_indexes.tolist()
+            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id and new_token != unk_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+    ans = [h.ys[context_size:] for h in best_hyps]
+
+    return ans
+
+
+def _deprecated_modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
+    It decodes only one utterance at a time. We keep it only for reference.
+    The function :func:`modified_beam_search` should be preferred as it
+    supports batch decoding.
+
+
     Args:
       model:
         An instance of `Transducer`.
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 57901b0c6..fd8d2dd0e 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -74,13 +74,9 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -182,7 +178,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
     )
@@ -190,73 +186,6 @@
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.vocab_size,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
-        blank_id=params.blank_id,
-        unk_id=params.unk_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -329,6 +258,14 @@
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
@@ -348,12 +285,6 @@
                 encoder_out=encoder_out_i,
                 beam=params.beam_size,
             )
-            elif params.decoding_method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -593,8 +524,5 @@ def main():
     logging.info("Done!")
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
deleted file mode 100644
index 8c7a269c3..000000000
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang
-#                                                   Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Decoder(nn.Module):
-    """This class modifies the stateless decoder from the following paper:
-
-        RNN-transducer with stateless prediction network
-        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-
-    It removes the recurrent connection from the decoder, i.e., the prediction
-    network. Different from the above paper, it adds an extra Conv1d
-    right after the embedding layer.
-
-    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        unk_id: int,
-        context_size: int,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          unk_id:
-            The ID of the unk symbol.
-          context_size:
-            Number of previous words to use to predict the next word.
-            1 means bigram; 2 means trigram. n means (n+1)-gram.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.blank_id = blank_id
-        self.unk_id = unk_id
-        assert context_size >= 1, context_size
-        self.context_size = context_size
-        self.vocab_size = vocab_size
-        if context_size > 1:
-            self.conv = nn.Conv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
-                kernel_size=context_size,
-                padding=0,
-                groups=embedding_dim,
-                bias=False,
-            )
-        self.output_linear = nn.Linear(embedding_dim, vocab_size)
-
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U).
-          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
-        Returns:
-          Return a tensor of shape (N, U, embedding_dim).
- """ - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = self.output_linear(F.relu(embedding_out)) - return embedding_out diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py new file mode 120000 index 000000000..206384eaa --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py index 9e87f8c8a..1e6edbb99 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang -# Mingshuang Luo) +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -50,15 +50,10 @@ from pathlib import Path import sentencepiece as spm import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import str2bool def get_parser(): @@ -69,7 +64,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=20, + default=30, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -77,7 +72,7 @@ def get_parser(): parser.add_argument( "--avg", type=int, - default=10, + default=13, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. 
", @@ -118,73 +113,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for decoder - "embedding_dim": 512, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - unk_id=params.unk_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - inner_dim=params.embedding_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir)