Fix comments

This commit is contained in:
pkufool 2022-03-18 10:45:47 +08:00
parent dbfe8fbb1a
commit 089b8178f0
2 changed files with 23 additions and 15 deletions

View File

@ -26,23 +26,26 @@ from icefall.utils import get_texts
def fast_beam_search( def fast_beam_search(
decoding_graph: k2.Fsa,
model: Transducer, model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
) -> List[int]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
Args: Args:
model: model:
An instance of `Transducer`. An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam: beam:
Beam value, similar to the beam used in Kaldi.. Beam value, similar to the beam used in Kaldi..
max_states: max_states:
@ -66,15 +69,17 @@ def fast_beam_search(
max_contexts=max_contexts, max_contexts=max_contexts,
max_states=max_states, max_states=max_states,
) )
indivisual_streams = [] individual_streams = []
for i in range(B): for i in range(B):
indivisual_streams.append(k2.RnntDecodingStream(decoding_graph)) individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(indivisual_streams, config) decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T): for t in range(T):
# shape is a RaggedShape of shape (B, context) # shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size) # contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts() shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False) decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape # current_encoder_out is of shape
@ -90,7 +95,7 @@ def fast_beam_search(
logits = logits.squeeze(1).squeeze(1) logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs) decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_atreams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)

View File

@ -60,7 +60,7 @@ import argparse
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import sentencepiece as spm import sentencepiece as spm
@ -135,16 +135,19 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""Used only when --decoding-method is help="""An interger indicating how many candidates we will keep for each
beam_search or modified_beam_search""", frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
) )
parser.add_argument( parser.add_argument(
"--beam", "--beam",
type=float, type=float,
default=4, default=4,
help="""Used only when --decoding-method is help="""A floating point value to calculate the cutoff score during beam
fast_beam_search""", search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -185,8 +188,8 @@ def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
decoding_graph: k2.Fsa,
batch: dict, batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -293,7 +296,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
decoding_graph: k2.Fsa, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.