From 089b8178f079302667942939126827d99502f8e1 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 18 Mar 2022 10:45:47 +0800 Subject: [PATCH] Fix comments --- .../beam_search.py | 21 ++++++++++++------- .../ASR/pruned_transducer_stateless/decode.py | 17 ++++++++------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index e5fad74c4..651854999 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -26,23 +26,26 @@ from icefall.utils import get_texts def fast_beam_search( - decoding_graph: k2.Fsa, model: Transducer, + decoding_graph: k2.Fsa, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, beam: float, max_states: int, max_contexts: int, -) -> List[int]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. Args: model: An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. decoding_graph: Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. beam: Beam value, similar to the beam used in Kaldi.. 
max_states: @@ -66,15 +69,17 @@ def fast_beam_search( max_contexts=max_contexts, max_states=max_states, ) - indivisual_streams = [] + individual_streams = [] for i in range(B): - indivisual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(indivisual_streams, config) + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) for t in range(T): # shape is a RaggedShape of shape (B, context) # contexts is a Tensor of shape (shape.NumElements(), context_size) shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) decoder_out = model.decoder(contexts, need_pad=False) # current_encoder_out is of shape @@ -90,7 +95,7 @@ def fast_beam_search( logits = logits.squeeze(1).squeeze(1) log_probs = logits.log_softmax(dim=-1) decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_atreams() + decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) best_path = one_best_decoding(lattice) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index b29f69db0..2393c3a3f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -60,7 +60,7 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import k2 import sentencepiece as spm @@ -135,16 +135,19 @@ def get_parser(): "--beam-size", type=int, default=4, - help="""Used only when --decoding-method is - beam_search or modified_beam_search""", + help="""An integer indicating how many 
candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", ) parser.add_argument( "--beam", type=float, default=4, - help="""Used only when --decoding-method is - fast_beam_search""", + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( @@ -185,8 +188,8 @@ def decode_one_batch( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, - decoding_graph: k2.Fsa, batch: dict, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -293,7 +296,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, - decoding_graph: k2.Fsa, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset.